@@ -91,7 +91,7 @@ public Task<GenerateContentResponse> GenerateContentAsync(
91
91
/// <summary>
92
92
/// Generates new content from input text given to the model as a prompt.
93
93
/// </summary>
94
- /// <param name="content ">The text given to the model as a prompt.</param>
94
+ /// <param name="text ">The text given to the model as a prompt.</param>
95
95
/// <returns>The generated content response from the model.</returns>
96
96
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
97
97
public Task < GenerateContentResponse > GenerateContentAsync (
@@ -122,7 +122,7 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
122
122
/// <summary>
123
123
/// Generates new content as a stream from input text given to the model as a prompt.
124
124
/// </summary>
125
- /// <param name="content ">The text given to the model as a prompt.</param>
125
+ /// <param name="text ">The text given to the model as a prompt.</param>
126
126
/// <returns>A stream of generated content responses from the model.</returns>
127
127
/// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
128
128
public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
@@ -140,14 +140,32 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
140
140
return GenerateContentStreamAsyncInternal ( content ) ;
141
141
}
142
142
143
+ /// <summary>
144
+ /// Counts the number of tokens in a prompt using the model's tokenizer.
145
+ /// </summary>
146
+ /// <param name="content">The input(s) given to the model as a prompt.</param>
147
+ /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
148
+ /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
143
149
public Task < CountTokensResponse > CountTokensAsync (
144
150
params ModelContent [ ] content ) {
145
151
return CountTokensAsync ( ( IEnumerable < ModelContent > ) content ) ;
146
152
}
153
+ /// <summary>
154
+ /// Counts the number of tokens in a prompt using the model's tokenizer.
155
+ /// </summary>
156
+ /// <param name="text">The text input given to the model as a prompt.</param>
157
+ /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
158
+ /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
147
159
public Task < CountTokensResponse > CountTokensAsync (
148
160
string text ) {
149
161
return CountTokensAsync ( new ModelContent [ ] { ModelContent . Text ( text ) } ) ;
150
162
}
163
+ /// <summary>
164
+ /// Counts the number of tokens in a prompt using the model's tokenizer.
165
+ /// </summary>
166
+ /// <param name="content">The input(s) given to the model as a prompt.</param>
167
+ /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
168
+ /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
151
169
public Task < CountTokensResponse > CountTokensAsync (
152
170
IEnumerable < ModelContent > content ) {
153
171
return CountTokensAsyncInternal ( content ) ;
@@ -184,12 +202,16 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
184
202
UnityEngine . Debug . Log ( "Request:\n " + bodyJson ) ;
185
203
#endif
186
204
187
- HttpResponseMessage response = await _httpClient . SendAsync ( request ) ;
188
- // TODO: Convert any timeout exception into a VertexAI equivalent
189
- // TODO: Convert any HttpRequestExceptions, see:
190
- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
191
- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
192
- response . EnsureSuccessStatusCode ( ) ;
205
+ HttpResponseMessage response ;
206
+ try {
207
+ response = await _httpClient . SendAsync ( request ) ;
208
+ response . EnsureSuccessStatusCode ( ) ;
209
+ } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
210
+ throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
211
+ } catch ( HttpRequestException e ) {
212
+ // TODO: Convert to a more precise exception when possible.
213
+ throw new VertexAIException ( "HTTP request failed." , e ) ;
214
+ }
193
215
194
216
string result = await response . Content . ReadAsStringAsync ( ) ;
195
217
@@ -215,13 +237,16 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
215
237
UnityEngine . Debug . Log ( "Request:\n " + bodyJson ) ;
216
238
#endif
217
239
218
- HttpResponseMessage response =
219
- await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
220
- // TODO: Convert any timeout exception into a VertexAI equivalent
221
- // TODO: Convert any HttpRequestExceptions, see:
222
- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
223
- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
224
- response . EnsureSuccessStatusCode ( ) ;
240
+ HttpResponseMessage response ;
241
+ try {
242
+ response = await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
243
+ response . EnsureSuccessStatusCode ( ) ;
244
+ } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
245
+ throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
246
+ } catch ( HttpRequestException e ) {
247
+ // TODO: Convert to a more precise exception when possible.
248
+ throw new VertexAIException ( "HTTP request failed." , e ) ;
249
+ }
225
250
226
251
// We are expecting a Stream as the response, so handle that.
227
252
using var stream = await response . Content . ReadAsStreamAsync ( ) ;
@@ -242,9 +267,37 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
242
267
243
268
private async Task < CountTokensResponse > CountTokensAsyncInternal (
244
269
IEnumerable < ModelContent > content ) {
245
- // TODO: Implementation
246
- await Task . CompletedTask ;
247
- throw new NotImplementedException ( ) ;
270
+ HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":countTokens" ) ;
271
+
272
+ // Set the request headers
273
+ SetRequestHeaders ( request ) ;
274
+
275
+ // Set the content
276
+ string bodyJson = MakeCountTokensRequest ( content ) ;
277
+ request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
278
+
279
+ #if FIREBASE_LOG_REST_CALLS
280
+ UnityEngine . Debug . Log ( "CountTokensRequest:\n " + bodyJson ) ;
281
+ #endif
282
+
283
+ HttpResponseMessage response ;
284
+ try {
285
+ response = await _httpClient . SendAsync ( request ) ;
286
+ response . EnsureSuccessStatusCode ( ) ;
287
+ } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
288
+ throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
289
+ } catch ( HttpRequestException e ) {
290
+ // TODO: Convert to a more precise exception when possible.
291
+ throw new VertexAIException ( "HTTP request failed." , e ) ;
292
+ }
293
+
294
+ string result = await response . Content . ReadAsStringAsync ( ) ;
295
+
296
+ #if FIREBASE_LOG_REST_CALLS
297
+ UnityEngine . Debug . Log ( "CountTokensResponse:\n " + result ) ;
298
+ #endif
299
+
300
+ return CountTokensResponse . FromJson ( result ) ;
248
301
}
249
302
250
303
private string GetURL ( ) {
@@ -283,6 +336,25 @@ private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
283
336
284
337
return Json . Serialize ( jsonDict ) ;
285
338
}
339
+
340
+ // CountTokensRequest is a subset of the full info needed for GenerateContent
341
+ private string MakeCountTokensRequest ( IEnumerable < ModelContent > contents ) {
342
+ Dictionary < string , object > jsonDict = new ( ) {
343
+ // Convert the Contents into a list of Json dictionaries
344
+ [ "contents" ] = contents . Select ( c => c . ToJson ( ) ) . ToList ( )
345
+ } ;
346
+ if ( _generationConfig . HasValue ) {
347
+ jsonDict [ "generationConfig" ] = _generationConfig ? . ToJson ( ) ;
348
+ }
349
+ if ( _tools != null && _tools . Length > 0 ) {
350
+ jsonDict [ "tools" ] = _tools . Select ( t => t . ToJson ( ) ) . ToList ( ) ;
351
+ }
352
+ if ( _systemInstruction . HasValue ) {
353
+ jsonDict [ "systemInstruction" ] = _systemInstruction ? . ToJson ( ) ;
354
+ }
355
+
356
+ return Json . Serialize ( jsonDict ) ;
357
+ }
286
358
}
287
359
288
360
}
0 commit comments