14
14
* limitations under the License.
15
15
*/
16
16
17
- // For now, using this to hide some functions causing problems with the build.
18
- #define HIDE_IASYNCENUMERABLE
19
-
20
17
using System ;
21
18
using System . Collections . Generic ;
19
+ using System . IO ;
22
20
using System . Linq ;
23
21
using System . Net . Http ;
24
22
using System . Text ;
@@ -45,6 +43,8 @@ public class GenerativeModel {
45
43
private readonly RequestOptions ? _requestOptions ;
46
44
47
45
private readonly HttpClient _httpClient ;
46
+ // String prefix to look for when handling streaming a response.
47
+ private const string StreamPrefix = "data: " ;
48
48
49
49
/// <summary>
50
50
/// Intended for internal use only.
@@ -107,7 +107,6 @@ public Task<GenerateContentResponse> GenerateContentAsync(
107
107
return GenerateContentAsyncInternal ( content ) ;
108
108
}
109
109
110
- #if ! HIDE_IASYNCENUMERABLE
111
110
public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
112
111
params ModelContent [ ] content ) {
113
112
return GenerateContentStreamAsync ( ( IEnumerable < ModelContent > ) content ) ;
@@ -120,11 +119,6 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
120
119
IEnumerable < ModelContent > content ) {
121
120
return GenerateContentStreamAsyncInternal ( content ) ;
122
121
}
123
- public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
124
- IEnumerable < ModelContent > content ) {
125
- return GenerateContentStreamAsyncInternal ( content ) ;
126
- }
127
- #endif
128
122
129
123
public Task < CountTokensResponse > CountTokensAsync (
130
124
params ModelContent [ ] content ) {
@@ -148,17 +142,20 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
148
142
}
149
143
#endregion
150
144
145
+ private void SetRequestHeaders ( HttpRequestMessage request ) {
146
+ request . Headers . Add ( "x-goog-api-key" , _firebaseApp . Options . ApiKey ) ;
147
+ request . Headers . Add ( "x-goog-api-client" , "genai-csharp/0.1.0" ) ;
148
+ }
149
+
151
150
private async Task < GenerateContentResponse > GenerateContentAsyncInternal (
152
151
IEnumerable < ModelContent > content ) {
153
- string bodyJson = ModelContentsToJson ( content ) ;
154
-
155
152
HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":generateContent" ) ;
156
153
157
154
// Set the request headers
158
- request . Headers . Add ( "x-goog-api-key" , _firebaseApp . Options . ApiKey ) ;
159
- request . Headers . Add ( "x-goog-api-client" , "genai-csharp/0.1.0" ) ;
155
+ SetRequestHeaders ( request ) ;
160
156
161
157
// Set the content
158
+ string bodyJson = ModelContentsToJson ( content ) ;
162
159
request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
163
160
164
161
HttpResponseMessage response = await _httpClient . SendAsync ( request ) ;
@@ -169,19 +166,40 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
169
166
response . EnsureSuccessStatusCode ( ) ;
170
167
171
168
string result = await response . Content . ReadAsStringAsync ( ) ;
172
-
173
169
return GenerateContentResponse . FromJson ( result ) ;
174
170
}
175
171
176
- #if ! HIDE_IASYNCENUMERABLE
177
172
private async IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsyncInternal (
178
173
IEnumerable < ModelContent > content ) {
179
- // TODO: Implementation
180
- await Task . CompletedTask ;
181
- yield return new GenerateContentResponse ( ) ;
182
- throw new NotImplementedException ( ) ;
174
+ HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":streamGenerateContent?alt=sse" ) ;
175
+
176
+ // Set the request headers
177
+ SetRequestHeaders ( request ) ;
178
+
179
+ // Set the content
180
+ string bodyJson = ModelContentsToJson ( content ) ;
181
+ request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
182
+
183
+ HttpResponseMessage response =
184
+ await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
185
+ // TODO: Convert any timeout exception into a VertexAI equivalent
186
+ // TODO: Convert any HttpRequestExceptions, see:
187
+ // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
188
+ // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
189
+ response . EnsureSuccessStatusCode ( ) ;
190
+
191
+ // We are expecting a Stream as the response, so handle that.
192
+ using var stream = await response . Content . ReadAsStreamAsync ( ) ;
193
+ using var reader = new StreamReader ( stream ) ;
194
+
195
+ string line ;
196
+ while ( ( line = await reader . ReadLineAsync ( ) ) != null ) {
197
+ // Only pass along strings that begin with the expected prefix.
198
+ if ( line . StartsWith ( StreamPrefix ) ) {
199
+ yield return GenerateContentResponse . FromJson ( line [ StreamPrefix . Length ..] ) ;
200
+ }
201
+ }
183
202
}
184
- #endif
185
203
186
204
private async Task < CountTokensResponse > CountTokensAsyncInternal (
187
205
IEnumerable < ModelContent > content ) {
0 commit comments