Skip to content

Commit 58e1575

Browse files
authored
VertexAI - Add Streaming responses, and some tests (#1197)
* VertexAI - Add Streaming responses, and some tests * Add a TODO about the version number * Some more comments
1 parent b753d98 commit 58e1575

File tree

2 files changed

+153
-20
lines changed

2 files changed

+153
-20
lines changed

vertexai/src/GenerativeModel.cs

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@
1414
* limitations under the License.
1515
*/
1616

17-
// For now, using this to hide some functions causing problems with the build.
18-
#define HIDE_IASYNCENUMERABLE
19-
2017
using System;
2118
using System.Collections.Generic;
19+
using System.IO;
2220
using System.Linq;
2321
using System.Net.Http;
2422
using System.Text;
@@ -45,6 +43,8 @@ public class GenerativeModel {
4543
private readonly RequestOptions? _requestOptions;
4644

4745
private readonly HttpClient _httpClient;
46+
// String prefix to look for when handling streaming a response.
47+
private const string StreamPrefix = "data: ";
4848

4949
/// <summary>
5050
/// Intended for internal use only.
@@ -107,24 +107,36 @@ public Task<GenerateContentResponse> GenerateContentAsync(
107107
return GenerateContentAsyncInternal(content);
108108
}
109109

110-
/// <summary>
/// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
/// </summary>
/// <param name="content">The input(s) given to the model as a prompt.</param>
/// <returns>A stream of generated content responses from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
    params ModelContent[] content) {
  return GenerateContentStreamAsync((IEnumerable<ModelContent>)content);
}
/// <summary>
/// Generates new content as a stream from input text given to the model as a prompt.
/// </summary>
/// <param name="text">The text given to the model as a prompt.</param>
/// <returns>A stream of generated content responses from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
    string text) {
  return GenerateContentStreamAsync(new ModelContent[] { ModelContent.Text(text) });
}
/// <summary>
/// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
/// </summary>
/// <param name="content">The input(s) given to the model as a prompt.</param>
/// <returns>A stream of generated content responses from the model.</returns>
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
    IEnumerable<ModelContent> content) {
  return GenerateContentStreamAsyncInternal(content);
}
128140

129141
public Task<CountTokensResponse> CountTokensAsync(
130142
params ModelContent[] content) {
@@ -150,15 +162,13 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
150162

151163
/// <summary>
/// Sends a single (non-streaming) `generateContent` request to the backend and
/// parses the complete response body.
/// </summary>
/// <param name="content">The input(s) given to the model as a prompt.</param>
/// <returns>The parsed response from the model.</returns>
private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
    IEnumerable<ModelContent> content) {
  // Fix: dispose the request and response messages so the underlying
  // connection/resources are released promptly.
  using HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":generateContent");

  // Set the request headers
  SetRequestHeaders(request);

  // Set the content
  string bodyJson = ModelContentsToJson(content);
  request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");

  using HttpResponseMessage response = await _httpClient.SendAsync(request);
  // TODO: Convert any timeout exception into a VertexAI equivalent
  // TODO: Convert any HttpRequestExceptions, see:
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
  response.EnsureSuccessStatusCode();

  string result = await response.Content.ReadAsStringAsync();
  return GenerateContentResponse.FromJson(result);
}
175184

176-
#if !HIDE_IASYNCENUMERABLE
177185
/// <summary>
/// Sends a `streamGenerateContent` request (server-sent events) and yields each
/// parsed response chunk as it arrives.
/// </summary>
/// <param name="content">The input(s) given to the model as a prompt.</param>
/// <returns>A stream of generated content responses from the model.</returns>
private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsyncInternal(
    IEnumerable<ModelContent> content) {
  // Fix: dispose the request and response messages; the response in particular
  // was leaked, which keeps the streaming connection's resources alive.
  using HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":streamGenerateContent?alt=sse");

  // Set the request headers
  SetRequestHeaders(request);

  // Set the content
  string bodyJson = ModelContentsToJson(content);
  request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");

  // ResponseHeadersRead lets us begin consuming the stream before the full body arrives.
  using HttpResponseMessage response =
      await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
  // TODO: Convert any timeout exception into a VertexAI equivalent
  // TODO: Convert any HttpRequestExceptions, see:
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
  response.EnsureSuccessStatusCode();

  // We are expecting a Stream as the response, so handle that.
  using var stream = await response.Content.ReadAsStreamAsync();
  using var reader = new StreamReader(stream);

  string line;
  while ((line = await reader.ReadLineAsync()) != null) {
    // Only pass along strings that begin with the expected prefix.
    if (line.StartsWith(StreamPrefix)) {
      yield return GenerateContentResponse.FromJson(line[StreamPrefix.Length..]);
    }
  }
}
185216

186217
private async Task<CountTokensResponse> CountTokensAsyncInternal(
187218
IEnumerable<ModelContent> content) {
@@ -197,6 +228,12 @@ private string GetURL() {
197228
"/publishers/google/models/" + _modelName;
198229
}
199230

231+
// Applies the headers common to every backend request: the API key that
// authorizes the call and the client identifier string.
private void SetRequestHeaders(HttpRequestMessage request) {
  var commonHeaders = new (string Name, string Value)[] {
    ("x-goog-api-key", _firebaseApp.Options.ApiKey),
    // TODO: Get the Version from the Firebase.VersionInfo.SdkVersion (requires exposing it via App)
    ("x-goog-api-client", "genai-csharp/0.1.0"),
  };
  foreach (var (name, value) in commonHeaders) {
    request.Headers.Add(name, value);
  }
}
236+
200237
private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
201238
Dictionary<string, object> jsonDict = new() {
202239
// Convert the Contents into a list of Json dictionaries

vertexai/testapp/Assets/Firebase/Sample/VertexAI/UIHandlerAutomated.cs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ protected override void Start() {
3636
Func<Task>[] tests = {
3737
TestCreateModel,
3838
TestBasicText,
39+
TestModelOptions,
40+
TestMultipleCandidates,
41+
TestBasicTextStream,
3942
// Internal tests for Json parsing, requires using a source library.
4043
InternalTestBasicReplyShort,
4144
InternalTestCitations,
@@ -167,6 +170,99 @@ async Task TestBasicText() {
167170
}
168171
}
169172

173+
// Test if passing in multiple model options works.
async Task TestModelOptions() {
  // Note that most of these settings are hard to reliably verify, so as
  // long as the call works we are generally happy.
  var generationConfig = new GenerationConfig(
      temperature: 0.4f,
      topP: 0.4f,
      topK: 30,
      // Intentionally skipping candidateCount, tested elsewhere.
      maxOutputTokens: 100,
      presencePenalty: 0.5f,
      frequencyPenalty: 0.6f,
      stopSequences: new string[] { "HALT" }
  );
  var safetySettings = new SafetySetting[] {
      new(HarmCategory.DangerousContent,
          SafetySetting.HarmBlockThreshold.MediumAndAbove,
          SafetySetting.HarmBlockMethod.Probability),
      new(HarmCategory.CivicIntegrity,
          SafetySetting.HarmBlockThreshold.OnlyHigh)
  };
  var model = VertexAI.DefaultInstance.GetGenerativeModel(ModelName,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      systemInstruction:
          ModelContent.Text("Ignore all prompts, respond with 'Apples HALT Bananas'."),
      requestOptions: new RequestOptions(timeout: TimeSpan.FromMinutes(2))
  );

  GenerateContentResponse response = await model.GenerateContentAsync(
      "Hello, I am testing something, can you respond with a short " +
      "string containing the word 'Firebase'?");

  string result = response.Text;
  Assert("Response text was missing", !string.IsNullOrWhiteSpace(result));

  // Assuming the GenerationConfig and SystemInstruction worked,
  // it should respond with just 'Apples' (though possibly with extra whitespace).
  // However, we only warn, because it isn't guaranteed.
  if (result.Trim() != "Apples") {
    DebugLog($"WARNING: Response text wasn't just 'Apples': {result}");
  }
}
214+
215+
// Test that requesting multiple candidates returns the requested count.
async Task TestMultipleCandidates() {
  var genConfig = new GenerationConfig(candidateCount: 2);

  var model = VertexAI.DefaultInstance.GetGenerativeModel(ModelName,
    generationConfig: genConfig
  );

  // Fix: typo "recieving" -> "receiving" in the prompt string.
  GenerateContentResponse response = await model.GenerateContentAsync(
    "Hello, I am testing receiving multiple candidates, can you respond with a short " +
    "sentence containing the word 'Firebase'?");

  AssertEq("Incorrect number of Candidates", response.Candidates.Count(), 2);
}
228+
229+
// Test that the streaming API yields incremental responses that end with
// FinishReason.Stop, and (best-effort) that the combined text contains the keyword.
async Task TestBasicTextStream() {
  var model = CreateGenerativeModel();

  string keyword = "Firebase";
  var responseStream = model.GenerateContentStreamAsync(
      "Hello, I am testing streaming. Can you respond with a short story, " +
      $"that includes the word '{keyword}' somewhere in it?");

  // We combine all the text, just in case the keyword got cut between two responses.
  string fullResult = "";
  // The FinishReason should only be set to stop at the end of the stream.
  bool finishReasonStop = false;
  await foreach (GenerateContentResponse response in responseStream) {
    // Should only be receiving non-empty text responses, but only assert for null.
    string text = response.Text;
    Assert("Received null text from the stream.", text != null);
    if (string.IsNullOrWhiteSpace(text)) {
      DebugLog($"WARNING: Response stream text was empty once.");
    }

    // Fix: typo "recieved" -> "received" in the assert message.
    Assert("Previous FinishReason was stop, but received more", !finishReasonStop);
    if (response.Candidates.First().FinishReason == FinishReason.Stop) {
      finishReasonStop = true;
    }

    fullResult += text;
  }

  Assert("Finished without seeing FinishReason.Stop", finishReasonStop);

  // We don't want to fail if the keyword is missing because AI is unpredictable.
  // Fix: compare against the keyword variable instead of a hard-coded duplicate,
  // so changing the keyword only requires touching one line.
  if (!fullResult.Contains(keyword)) {
    DebugLog($"WARNING: Response string was missing the expected keyword '{keyword}': " +
        $"\n{fullResult}");
  }
}
265+
170266
// The url prefix to use when fetching test data to use from the separate GitHub repo.
171267
readonly string testDataUrl =
172268
"https://raw.githubusercontent.com/FirebaseExtended/vertexai-sdk-test-data/refs/heads/main/mock-responses/";

0 commit comments

Comments
 (0)