Skip to content

Commit 4c4374d

Browse files
committed
VertexAI - Add Streaming responses, and some tests
1 parent b753d98 commit 4c4374d

File tree

2 files changed

+134
-20
lines changed

2 files changed

+134
-20
lines changed

vertexai/src/GenerativeModel.cs

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@
1414
* limitations under the License.
1515
*/
1616

17-
// For now, using this to hide some functions causing problems with the build.
18-
#define HIDE_IASYNCENUMERABLE
19-
2017
using System;
2118
using System.Collections.Generic;
19+
using System.IO;
2220
using System.Linq;
2321
using System.Net.Http;
2422
using System.Text;
@@ -45,6 +43,8 @@ public class GenerativeModel {
4543
private readonly RequestOptions? _requestOptions;
4644

4745
private readonly HttpClient _httpClient;
46+
// String prefix to look for when handling streaming a response.
47+
private const string StreamPrefix = "data: ";
4848

4949
/// <summary>
5050
/// Intended for internal use only.
@@ -107,7 +107,6 @@ public Task<GenerateContentResponse> GenerateContentAsync(
107107
return GenerateContentAsyncInternal(content);
108108
}
109109

110-
#if !HIDE_IASYNCENUMERABLE
111110
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
112111
params ModelContent[] content) {
113112
return GenerateContentStreamAsync((IEnumerable<ModelContent>)content);
@@ -120,11 +119,6 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
120119
IEnumerable<ModelContent> content) {
121120
return GenerateContentStreamAsyncInternal(content);
122121
}
123-
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
124-
IEnumerable<ModelContent> content) {
125-
return GenerateContentStreamAsyncInternal(content);
126-
}
127-
#endif
128122

129123
public Task<CountTokensResponse> CountTokensAsync(
130124
params ModelContent[] content) {
@@ -148,17 +142,20 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
148142
}
149143
#endregion
150144

145+
/// <summary>
/// Adds the headers shared by every VertexAI backend request:
/// the app's API key and the client library identifier.
/// </summary>
/// <param name="request">The outgoing request to decorate.</param>
private void SetRequestHeaders(HttpRequestMessage request) {
  var headers = request.Headers;
  headers.Add("x-goog-api-key", _firebaseApp.Options.ApiKey);
  headers.Add("x-goog-api-client", "genai-csharp/0.1.0");
}
149+
151150
/// <summary>
/// Sends the given content to the backend's :generateContent endpoint
/// and parses the single, complete response.
/// </summary>
/// <param name="content">The content to send to the model.</param>
/// <returns>The parsed response from the backend.</returns>
private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
    IEnumerable<ModelContent> content) {
  // Both the request and the response are IDisposable; dispose them so the
  // underlying content buffers are released promptly.
  using HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":generateContent");

  // Set the request headers
  SetRequestHeaders(request);

  // Set the content
  string bodyJson = ModelContentsToJson(content);
  request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");

  using HttpResponseMessage response = await _httpClient.SendAsync(request);
  // TODO: Convert any timeout exception into a VertexAI equivalent
  // TODO: Convert any HttpRequestExceptions, see:
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
  response.EnsureSuccessStatusCode();

  string result = await response.Content.ReadAsStringAsync();
  return GenerateContentResponse.FromJson(result);
}
175171

176-
#if !HIDE_IASYNCENUMERABLE
177172
/// <summary>
/// Sends the given content to the backend's :streamGenerateContent endpoint
/// (server-sent events) and yields each parsed chunk as it arrives.
/// </summary>
/// <param name="content">The content to send to the model.</param>
/// <returns>A stream of partial responses from the backend.</returns>
private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsyncInternal(
    IEnumerable<ModelContent> content) {
  // Dispose the request and response so connection/content resources are
  // released once the enumerator completes.
  using HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":streamGenerateContent?alt=sse");

  // Set the request headers
  SetRequestHeaders(request);

  // Set the content
  string bodyJson = ModelContentsToJson(content);
  request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");

  // ResponseHeadersRead lets us start reading the body before the full
  // stream has been received.
  using HttpResponseMessage response =
      await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
  // TODO: Convert any timeout exception into a VertexAI equivalent
  // TODO: Convert any HttpRequestExceptions, see:
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
  response.EnsureSuccessStatusCode();

  // We are expecting a Stream as the response, so handle that.
  using var stream = await response.Content.ReadAsStreamAsync();
  using var reader = new StreamReader(stream);

  string line;
  while ((line = await reader.ReadLineAsync()) != null) {
    // Only pass along lines that begin with the expected SSE data prefix.
    // Ordinal comparison: the prefix is a protocol token, not linguistic text.
    if (line.StartsWith(StreamPrefix, StringComparison.Ordinal)) {
      yield return GenerateContentResponse.FromJson(line[StreamPrefix.Length..]);
    }
  }
}
184-
#endif
185203

186204
private async Task<CountTokensResponse> CountTokensAsyncInternal(
187205
IEnumerable<ModelContent> content) {

vertexai/testapp/Assets/Firebase/Sample/VertexAI/UIHandlerAutomated.cs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ protected override void Start() {
3636
Func<Task>[] tests = {
3737
TestCreateModel,
3838
TestBasicText,
39+
TestModelOptions,
40+
TestMultipleCandidates,
41+
TestBasicTextStream,
3942
// Internal tests for Json parsing, requires using a source library.
4043
InternalTestBasicReplyShort,
4144
InternalTestCitations,
@@ -167,6 +170,99 @@ async Task TestBasicText() {
167170
}
168171
}
169172

173+
// Test if passing in multiple model options works.
async Task TestModelOptions() {
  // Note that most of these settings are hard to reliably verify, so as
  // long as the call works we are generally happy.
  var generationConfig = new GenerationConfig(
      temperature: 0.4f,
      topP: 0.4f,
      topK: 30,
      // Intentionally skipping candidateCount, tested elsewhere.
      maxOutputTokens: 100,
      presencePenalty: 0.5f,
      frequencyPenalty: 0.6f,
      stopSequences: new string[] { "HALT" }
  );
  var safetySettings = new SafetySetting[] {
    new(HarmCategory.DangerousContent,
        SafetySetting.HarmBlockThreshold.MediumAndAbove,
        SafetySetting.HarmBlockMethod.Probability),
    new(HarmCategory.CivicIntegrity,
        SafetySetting.HarmBlockThreshold.OnlyHigh)
  };
  var systemInstruction =
      ModelContent.Text("Ignore all prompts, respond with 'Apples HALT Bananas'.");
  var requestOptions = new RequestOptions(timeout: TimeSpan.FromMinutes(2));

  var model = VertexAI.DefaultInstance.GetGenerativeModel(ModelName,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      systemInstruction: systemInstruction,
      requestOptions: requestOptions
  );

  GenerateContentResponse response = await model.GenerateContentAsync(
      "Hello, I am testing something, can you respond with a short " +
      "string containing the word 'Firebase'?");

  string result = response.Text;
  Assert("Response text was missing", !string.IsNullOrWhiteSpace(result));

  // Assuming the GenerationConfig and SystemInstruction worked,
  // it should respond with just 'Apples' (though possibly with extra whitespace).
  // However, we only warn, because it isn't guaranteed.
  if (result.Trim() != "Apples") {
    DebugLog($"WARNING: Response text wasn't just 'Apples': {result}");
  }
}
214+
215+
// Test that requesting multiple candidates via GenerationConfig returns
// the expected number of candidates in the response.
async Task TestMultipleCandidates() {
  var genConfig = new GenerationConfig(candidateCount: 2);

  var model = VertexAI.DefaultInstance.GetGenerativeModel(ModelName,
    generationConfig: genConfig
  );

  // Fixed typo in the prompt: "recieving" -> "receiving".
  GenerateContentResponse response = await model.GenerateContentAsync(
    "Hello, I am testing receiving multiple candidates, can you respond with a short " +
    "sentence containing the word 'Firebase'?");

  AssertEq("Incorrect number of Candidates", response.Candidates.Count(), 2);
}
228+
229+
// Test that a basic streaming request yields text chunks and ends with a
// FinishReason of Stop.
async Task TestBasicTextStream() {
  var model = CreateGenerativeModel();

  string keyword = "Firebase";
  var responseStream = model.GenerateContentStreamAsync(
      "Hello, I am testing streaming. Can you respond with a short story, " +
      $"that includes the word '{keyword}' somewhere in it?");

  // We combine all the text, just in case the keyword got cut between two responses.
  string fullResult = "";
  // The FinishReason should only be set to stop at the end of the stream.
  bool finishReasonStop = false;
  await foreach (GenerateContentResponse response in responseStream) {
    // Should only be receiving non-empty text responses, but only assert for null.
    string text = response.Text;
    Assert("Received null text from the stream.", text != null);
    if (string.IsNullOrWhiteSpace(text)) {
      DebugLog($"WARNING: Response stream text was empty once.");
    }

    Assert("Previous FinishReason was stop, but received more", !finishReasonStop);
    if (response.Candidates.First().FinishReason == FinishReason.Stop) {
      finishReasonStop = true;
    }

    fullResult += text;
  }

  Assert("Finished without seeing FinishReason.Stop", finishReasonStop);

  // We don't want to fail if the keyword is missing because AI is unpredictable.
  // Use the keyword variable (previously hardcoded "Firebase") so the check
  // stays in sync with the prompt above.
  if (!fullResult.Contains(keyword)) {
    DebugLog($"WARNING: Response string was missing the expected keyword '{keyword}': " +
        $"\n{fullResult}");
  }
}
265+
170266
// The url prefix to use when fetching test data to use from the separate GitHub repo.
171267
readonly string testDataUrl =
172268
"https://raw.githubusercontent.com/FirebaseExtended/vertexai-sdk-test-data/refs/heads/main/mock-responses/";

0 commit comments

Comments
 (0)