14
14
* limitations under the License.
15
15
*/
16
16
17
- // For now, using this to hide some functions causing problems with the build.
18
- #define HIDE_IASYNCENUMERABLE
19
-
20
17
using System ;
21
18
using System . Collections . Generic ;
19
+ using System . IO ;
22
20
using System . Linq ;
23
21
using System . Net . Http ;
24
22
using System . Text ;
@@ -45,6 +43,8 @@ public class GenerativeModel {
45
43
private readonly RequestOptions ? _requestOptions ;
46
44
47
45
private readonly HttpClient _httpClient ;
46
+ // String prefix to look for when handling streaming a response.
47
+ private const string StreamPrefix = "data: " ;
48
48
49
49
/// <summary>
50
50
/// Intended for internal use only.
@@ -107,24 +107,36 @@ public Task<GenerateContentResponse> GenerateContentAsync(
107
107
return GenerateContentAsyncInternal ( content ) ;
108
108
}
109
109
110
- #if ! HIDE_IASYNCENUMERABLE
110
+ /// <summary>
111
+ /// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
112
+ /// </summary>
113
+ /// <param name="content">The input(s) given to the model as a prompt.</param>
114
+ /// <returns>A stream of generated content responses from the model.</returns>
115
+ /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
111
116
public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
112
117
params ModelContent [ ] content ) {
113
118
return GenerateContentStreamAsync ( ( IEnumerable < ModelContent > ) content ) ;
114
119
}
120
+ /// <summary>
121
+ /// Generates new content as a stream from input text given to the model as a prompt.
122
+ /// </summary>
123
+ /// <param name="content">The text given to the model as a prompt.</param>
124
+ /// <returns>A stream of generated content responses from the model.</returns>
125
+ /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
115
126
public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
116
127
string text ) {
117
128
return GenerateContentStreamAsync ( new ModelContent [ ] { ModelContent . Text ( text ) } ) ;
118
129
}
130
+ /// <summary>
131
+ /// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
132
+ /// </summary>
133
+ /// <param name="content">The input(s) given to the model as a prompt.</param>
134
+ /// <returns>A stream of generated content responses from the model.</returns>
135
+ /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
119
136
public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
120
137
IEnumerable < ModelContent > content ) {
121
138
return GenerateContentStreamAsyncInternal ( content ) ;
122
139
}
123
- public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
124
- IEnumerable < ModelContent > content ) {
125
- return GenerateContentStreamAsyncInternal ( content ) ;
126
- }
127
- #endif
128
140
129
141
public Task < CountTokensResponse > CountTokensAsync (
130
142
params ModelContent [ ] content ) {
@@ -150,15 +162,13 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
150
162
151
163
private async Task < GenerateContentResponse > GenerateContentAsyncInternal (
152
164
IEnumerable < ModelContent > content ) {
153
- string bodyJson = ModelContentsToJson ( content ) ;
154
-
155
165
HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":generateContent" ) ;
156
166
157
167
// Set the request headers
158
- request . Headers . Add ( "x-goog-api-key" , _firebaseApp . Options . ApiKey ) ;
159
- request . Headers . Add ( "x-goog-api-client" , "genai-csharp/0.1.0" ) ;
168
+ SetRequestHeaders ( request ) ;
160
169
161
170
// Set the content
171
+ string bodyJson = ModelContentsToJson ( content ) ;
162
172
request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
163
173
164
174
HttpResponseMessage response = await _httpClient . SendAsync ( request ) ;
@@ -169,19 +179,40 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
169
179
response . EnsureSuccessStatusCode ( ) ;
170
180
171
181
string result = await response . Content . ReadAsStringAsync ( ) ;
172
-
173
182
return GenerateContentResponse . FromJson ( result ) ;
174
183
}
175
184
176
- #if ! HIDE_IASYNCENUMERABLE
177
185
private async IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsyncInternal (
178
186
IEnumerable < ModelContent > content ) {
179
- // TODO: Implementation
180
- await Task . CompletedTask ;
181
- yield return new GenerateContentResponse ( ) ;
182
- throw new NotImplementedException ( ) ;
187
+ HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":streamGenerateContent?alt=sse" ) ;
188
+
189
+ // Set the request headers
190
+ SetRequestHeaders ( request ) ;
191
+
192
+ // Set the content
193
+ string bodyJson = ModelContentsToJson ( content ) ;
194
+ request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
195
+
196
+ HttpResponseMessage response =
197
+ await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
198
+ // TODO: Convert any timeout exception into a VertexAI equivalent
199
+ // TODO: Convert any HttpRequestExceptions, see:
200
+ // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
201
+ // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
202
+ response . EnsureSuccessStatusCode ( ) ;
203
+
204
+ // We are expecting a Stream as the response, so handle that.
205
+ using var stream = await response . Content . ReadAsStreamAsync ( ) ;
206
+ using var reader = new StreamReader ( stream ) ;
207
+
208
+ string line ;
209
+ while ( ( line = await reader . ReadLineAsync ( ) ) != null ) {
210
+ // Only pass along strings that begin with the expected prefix.
211
+ if ( line . StartsWith ( StreamPrefix ) ) {
212
+ yield return GenerateContentResponse . FromJson ( line [ StreamPrefix . Length ..] ) ;
213
+ }
214
+ }
183
215
}
184
- #endif
185
216
186
217
private async Task < CountTokensResponse > CountTokensAsyncInternal (
187
218
IEnumerable < ModelContent > content ) {
@@ -197,6 +228,12 @@ private string GetURL() {
197
228
"/publishers/google/models/" + _modelName ;
198
229
}
199
230
231
+ private void SetRequestHeaders ( HttpRequestMessage request ) {
232
+ request . Headers . Add ( "x-goog-api-key" , _firebaseApp . Options . ApiKey ) ;
233
+ // TODO: Get the Version from the Firebase.VersionInfo.SdkVersion (requires exposing it via App)
234
+ request . Headers . Add ( "x-goog-api-client" , "genai-csharp/0.1.0" ) ;
235
+ }
236
+
200
237
private string ModelContentsToJson ( IEnumerable < ModelContent > contents ) {
201
238
Dictionary < string , object > jsonDict = new ( ) {
202
239
// Convert the Contents into a list of Json dictionaries
0 commit comments