Commit 1c50663

VertexAI - Add CountTokens and ModalityTokenCount (#1216)
* VertexAI - Add CountTokens and ModalityTokenCount
* VertexAI - Improve error handling
1 parent 88cb401 commit 1c50663

7 files changed: +327 and -22 lines changed

vertexai/src/CountTokensResponse.cs

Lines changed: 51 additions & 0 deletions
@@ -14,13 +14,64 @@
  * limitations under the License.
  */
 
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using Google.MiniJSON;
+using Firebase.VertexAI.Internal;
+
 namespace Firebase.VertexAI {
 
+/// <summary>
+/// The model's response to a count tokens request.
+/// </summary>
 public readonly struct CountTokensResponse {
+  /// <summary>
+  /// The total number of tokens in the input given to the model as a prompt.
+  /// </summary>
   public int TotalTokens { get; }
+  /// <summary>
+  /// The total number of billable characters in the text input given to the model as a prompt.
+  ///
+  /// > Important: This does not include billable image, video or other non-text input. See
+  /// [Vertex AI pricing](https://firebase.google.com/docs/vertex-ai/pricing) for details.
+  /// </summary>
   public int? TotalBillableCharacters { get; }
 
+  private readonly ReadOnlyCollection<ModalityTokenCount> _promptTokensDetails;
+  /// <summary>
+  /// The breakdown, by modality, of how many tokens are consumed by the prompt.
+  /// </summary>
+  public IEnumerable<ModalityTokenCount> PromptTokensDetails =>
+    _promptTokensDetails ?? new ReadOnlyCollection<ModalityTokenCount>(new List<ModalityTokenCount>());
+
   // Hidden constructor, users don't need to make this
+  private CountTokensResponse(int totalTokens,
+      int? totalBillableCharacters = null,
+      List<ModalityTokenCount> promptTokensDetails = null) {
+    TotalTokens = totalTokens;
+    TotalBillableCharacters = totalBillableCharacters;
+    _promptTokensDetails =
+        new ReadOnlyCollection<ModalityTokenCount>(promptTokensDetails ?? new List<ModalityTokenCount>());
+  }
+
+  /// <summary>
+  /// Intended for internal use only.
+  /// This method is used for deserializing JSON responses and should not be called directly.
+  /// </summary>
+  internal static CountTokensResponse FromJson(string jsonString) {
+    return FromJson(Json.Deserialize(jsonString) as Dictionary<string, object>);
+  }
+
+  /// <summary>
+  /// Intended for internal use only.
+  /// This method is used for deserializing JSON responses and should not be called directly.
+  /// </summary>
+  internal static CountTokensResponse FromJson(Dictionary<string, object> jsonDict) {
+    return new CountTokensResponse(
+        jsonDict.ParseValue<int>("totalTokens"),
+        jsonDict.ParseNullableValue<int>("totalBillableCharacters"),
+        jsonDict.ParseObjectList("promptTokensDetails", ModalityTokenCount.FromJson));
+  }
 }
 
 }
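
For orientation, here is a minimal usage sketch of the new response type. It assumes a GenerativeModel instance ("model") obtained elsewhere in the SDK (model construction is not part of this diff) and runs inside an async method.

// Illustrative only: reading the fields of the new CountTokensResponse.
// The "model" instance and its construction are assumptions, not part of this commit.
CountTokensResponse countResponse = await model.CountTokensAsync("Why is the sky blue?");

UnityEngine.Debug.Log($"Total tokens: {countResponse.TotalTokens}");
if (countResponse.TotalBillableCharacters.HasValue) {
  UnityEngine.Debug.Log($"Billable characters: {countResponse.TotalBillableCharacters.Value}");
}
foreach (ModalityTokenCount detail in countResponse.PromptTokensDetails) {
  UnityEngine.Debug.Log($"{detail.Modality}: {detail.TokenCount} tokens");
}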

vertexai/src/GenerateContentResponse.cs

Lines changed: 17 additions & 3 deletions
@@ -189,13 +189,25 @@ public readonly struct UsageMetadata {
   /// </summary>
   public int TotalTokenCount { get; }
 
-  // TODO: New fields about ModalityTokenCount
+  private readonly ReadOnlyCollection<ModalityTokenCount> _promptTokensDetails;
+  public IEnumerable<ModalityTokenCount> PromptTokensDetails =>
+    _promptTokensDetails ?? new ReadOnlyCollection<ModalityTokenCount>(new List<ModalityTokenCount>());
+
+  private readonly ReadOnlyCollection<ModalityTokenCount> _candidatesTokensDetails;
+  public IEnumerable<ModalityTokenCount> CandidatesTokensDetails =>
+    _candidatesTokensDetails ?? new ReadOnlyCollection<ModalityTokenCount>(new List<ModalityTokenCount>());
 
   // Hidden constructor, users don't need to make this.
-  private UsageMetadata(int promptTC, int candidatesTC, int totalTC) {
+  private UsageMetadata(int promptTC, int candidatesTC, int totalTC,
+      List<ModalityTokenCount> promptDetails, List<ModalityTokenCount> candidateDetails) {
     PromptTokenCount = promptTC;
     CandidatesTokenCount = candidatesTC;
     TotalTokenCount = totalTC;
+    _promptTokensDetails =
+        new ReadOnlyCollection<ModalityTokenCount>(promptDetails ?? new List<ModalityTokenCount>());
+    _candidatesTokensDetails =
+        new ReadOnlyCollection<ModalityTokenCount>(candidateDetails ?? new List<ModalityTokenCount>());
+
   }
 
   /// <summary>
@@ -206,7 +218,9 @@ internal static UsageMetadata FromJson(Dictionary<string, object> jsonDict) {
     return new UsageMetadata(
         jsonDict.ParseValue<int>("promptTokenCount"),
         jsonDict.ParseValue<int>("candidatesTokenCount"),
-        jsonDict.ParseValue<int>("totalTokenCount"));
+        jsonDict.ParseValue<int>("totalTokenCount"),
+        jsonDict.ParseObjectList("promptTokensDetails", ModalityTokenCount.FromJson),
+        jsonDict.ParseObjectList("candidatesTokensDetails", ModalityTokenCount.FromJson));
   }
 }
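
A short sketch of how the new per-modality breakdowns might be read after a generation call. It assumes GenerateContentResponse exposes this struct through a nullable UsageMetadata property; that property is not shown in this diff, and the "model" instance is the same assumed instance as in the earlier example.

// Illustrative only: the UsageMetadata property on the response is assumed here.
GenerateContentResponse result = await model.GenerateContentAsync("Describe the weather.");
if (result.UsageMetadata is UsageMetadata usage) {
  UnityEngine.Debug.Log($"Prompt tokens: {usage.PromptTokenCount}, total: {usage.TotalTokenCount}");
  foreach (ModalityTokenCount m in usage.PromptTokensDetails) {
    UnityEngine.Debug.Log($"Prompt {m.Modality}: {m.TokenCount}");
  }
  foreach (ModalityTokenCount m in usage.CandidatesTokensDetails) {
    UnityEngine.Debug.Log($"Candidates {m.Modality}: {m.TokenCount}");
  }
}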

vertexai/src/GenerativeModel.cs

Lines changed: 90 additions & 18 deletions
@@ -91,7 +91,7 @@ public Task<GenerateContentResponse> GenerateContentAsync(
   /// <summary>
   /// Generates new content from input text given to the model as a prompt.
   /// </summary>
-  /// <param name="content">The text given to the model as a prompt.</param>
+  /// <param name="text">The text given to the model as a prompt.</param>
   /// <returns>The generated content response from the model.</returns>
   /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
   public Task<GenerateContentResponse> GenerateContentAsync(
@@ -122,7 +122,7 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
   /// <summary>
   /// Generates new content as a stream from input text given to the model as a prompt.
   /// </summary>
-  /// <param name="content">The text given to the model as a prompt.</param>
+  /// <param name="text">The text given to the model as a prompt.</param>
   /// <returns>A stream of generated content responses from the model.</returns>
   /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
   public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
@@ -140,14 +140,32 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
     return GenerateContentStreamAsyncInternal(content);
   }
 
+  /// <summary>
+  /// Counts the number of tokens in a prompt using the model's tokenizer.
+  /// </summary>
+  /// <param name="content">The input(s) given to the model as a prompt.</param>
+  /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
   public Task<CountTokensResponse> CountTokensAsync(
       params ModelContent[] content) {
     return CountTokensAsync((IEnumerable<ModelContent>)content);
   }
+  /// <summary>
+  /// Counts the number of tokens in a prompt using the model's tokenizer.
+  /// </summary>
+  /// <param name="text">The text input given to the model as a prompt.</param>
+  /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
   public Task<CountTokensResponse> CountTokensAsync(
       string text) {
     return CountTokensAsync(new ModelContent[] { ModelContent.Text(text) });
   }
+  /// <summary>
+  /// Counts the number of tokens in a prompt using the model's tokenizer.
+  /// </summary>
+  /// <param name="content">The input(s) given to the model as a prompt.</param>
+  /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
+  /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
   public Task<CountTokensResponse> CountTokensAsync(
       IEnumerable<ModelContent> content) {
     return CountTokensAsyncInternal(content);
@@ -184,12 +202,16 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
     UnityEngine.Debug.Log("Request:\n" + bodyJson);
 #endif
 
-    HttpResponseMessage response = await _httpClient.SendAsync(request);
-    // TODO: Convert any timeout exception into a VertexAI equivalent
-    // TODO: Convert any HttpRequestExceptions, see:
-    // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
-    // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
-    response.EnsureSuccessStatusCode();
+    HttpResponseMessage response;
+    try {
+      response = await _httpClient.SendAsync(request);
+      response.EnsureSuccessStatusCode();
+    } catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
+      throw new VertexAIRequestTimeoutException("Request timed out.", e);
+    } catch (HttpRequestException e) {
+      // TODO: Convert to a more precise exception when possible.
+      throw new VertexAIException("HTTP request failed.", e);
+    }
 
     string result = await response.Content.ReadAsStringAsync();
 
@@ -215,13 +237,16 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
     UnityEngine.Debug.Log("Request:\n" + bodyJson);
 #endif
 
-    HttpResponseMessage response =
-        await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
-    // TODO: Convert any timeout exception into a VertexAI equivalent
-    // TODO: Convert any HttpRequestExceptions, see:
-    // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
-    // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
-    response.EnsureSuccessStatusCode();
+    HttpResponseMessage response;
+    try {
+      response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
+      response.EnsureSuccessStatusCode();
+    } catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
+      throw new VertexAIRequestTimeoutException("Request timed out.", e);
+    } catch (HttpRequestException e) {
+      // TODO: Convert to a more precise exception when possible.
+      throw new VertexAIException("HTTP request failed.", e);
+    }
 
     // We are expecting a Stream as the response, so handle that.
     using var stream = await response.Content.ReadAsStreamAsync();
@@ -242,9 +267,37 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
 
   private async Task<CountTokensResponse> CountTokensAsyncInternal(
       IEnumerable<ModelContent> content) {
-    // TODO: Implementation
-    await Task.CompletedTask;
-    throw new NotImplementedException();
+    HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":countTokens");
+
+    // Set the request headers
+    SetRequestHeaders(request);
+
+    // Set the content
+    string bodyJson = MakeCountTokensRequest(content);
+    request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");
+
+#if FIREBASE_LOG_REST_CALLS
+    UnityEngine.Debug.Log("CountTokensRequest:\n" + bodyJson);
+#endif
+
+    HttpResponseMessage response;
+    try {
+      response = await _httpClient.SendAsync(request);
+      response.EnsureSuccessStatusCode();
+    } catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
+      throw new VertexAIRequestTimeoutException("Request timed out.", e);
+    } catch (HttpRequestException e) {
+      // TODO: Convert to a more precise exception when possible.
+      throw new VertexAIException("HTTP request failed.", e);
+    }
+
+    string result = await response.Content.ReadAsStringAsync();
+
+#if FIREBASE_LOG_REST_CALLS
+    UnityEngine.Debug.Log("CountTokensResponse:\n" + result);
+#endif
+
+    return CountTokensResponse.FromJson(result);
   }
 
   private string GetURL() {
@@ -283,6 +336,25 @@ private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
 
     return Json.Serialize(jsonDict);
   }
+
+  // CountTokensRequest is a subset of the full info needed for GenerateContent
+  private string MakeCountTokensRequest(IEnumerable<ModelContent> contents) {
+    Dictionary<string, object> jsonDict = new() {
+      // Convert the Contents into a list of Json dictionaries
+      ["contents"] = contents.Select(c => c.ToJson()).ToList()
+    };
+    if (_generationConfig.HasValue) {
+      jsonDict["generationConfig"] = _generationConfig?.ToJson();
+    }
+    if (_tools != null && _tools.Length > 0) {
+      jsonDict["tools"] = _tools.Select(t => t.ToJson()).ToList();
+    }
+    if (_systemInstruction.HasValue) {
+      jsonDict["systemInstruction"] = _systemInstruction?.ToJson();
+    }
+
+    return Json.Serialize(jsonDict);
+  }
 }
 
 }
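
The three CountTokensAsync overloads added above can be called as in the following sketch; the "model" instance is assumed to come from the SDK's existing entry point, and the surrounding async context is omitted.

// Illustrative only: raw text, params ModelContent[], and IEnumerable<ModelContent> overloads.
CountTokensResponse fromText = await model.CountTokensAsync("How many tokens is this?");

CountTokensResponse fromParts = await model.CountTokensAsync(
    ModelContent.Text("You are a concise assistant."),
    ModelContent.Text("Summarize the rules of chess."));

IEnumerable<ModelContent> history = new List<ModelContent> {
  ModelContent.Text("First turn."),
  ModelContent.Text("Second turn."),
};
CountTokensResponse fromHistory = await model.CountTokensAsync(history);

UnityEngine.Debug.Log($"Text prompt uses {fromText.TotalTokens} tokens.");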

vertexai/src/ModalityTokenCount.cs

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System.Collections.Generic;
+using Firebase.VertexAI.Internal;
+
+namespace Firebase.VertexAI {
+
+/// <summary>
+/// Content part modality.
+/// </summary>
+public enum ContentModality {
+  /// <summary>
+  /// A new and not yet supported value.
+  /// </summary>
+  Unknown = 0,
+  /// <summary>
+  /// Plain text.
+  /// </summary>
+  Text,
+  /// <summary>
+  /// Image.
+  /// </summary>
+  Image,
+  /// <summary>
+  /// Video.
+  /// </summary>
+  Video,
+  /// <summary>
+  /// Audio.
+  /// </summary>
+  Audio,
+  /// <summary>
+  /// Document, e.g. PDF.
+  /// </summary>
+  Document,
+}
+
+/// <summary>
+/// Represents token counting info for a single modality.
+/// </summary>
+public readonly struct ModalityTokenCount {
+  /// <summary>
+  /// The modality associated with this token count.
+  /// </summary>
+  public ContentModality Modality { get; }
+  /// <summary>
+  /// The number of tokens counted.
+  /// </summary>
+  public int TokenCount { get; }
+
+  // Hidden constructor, users don't need to make this
+  private ModalityTokenCount(ContentModality modality, int tokenCount) {
+    Modality = modality;
+    TokenCount = tokenCount;
+  }
+
+  private static ContentModality ParseModality(string str) {
+    return str switch {
+      "TEXT" => ContentModality.Text,
+      "IMAGE" => ContentModality.Image,
+      "VIDEO" => ContentModality.Video,
+      "AUDIO" => ContentModality.Audio,
+      "DOCUMENT" => ContentModality.Document,
+      _ => ContentModality.Unknown,
+    };
+  }
+
+  /// <summary>
+  /// Intended for internal use only.
+  /// This method is used for deserializing JSON responses and should not be called directly.
+  /// </summary>
+  internal static ModalityTokenCount FromJson(Dictionary<string, object> jsonDict) {
+    return new ModalityTokenCount(
+        jsonDict.ParseEnum("modality", ParseModality),
+        jsonDict.ParseValue<int>("tokenCount"));
+  }
+}
+
+}
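
As a usage sketch, the per-modality entries can be folded into a simple per-modality total; the helper below is illustrative only and not part of this commit.

using System.Collections.Generic;
using System.Linq;
using Firebase.VertexAI;

// Illustrative helper, not part of the SDK.
public static class TokenCountExtensions {
  // Collapse a sequence of ModalityTokenCount entries into a total per modality.
  public static Dictionary<ContentModality, int> TotalsByModality(
      this IEnumerable<ModalityTokenCount> details) {
    return details
        .GroupBy(d => d.Modality)
        .ToDictionary(g => g.Key, g => g.Sum(d => d.TokenCount));
  }
}

With that helper, countResponse.PromptTokensDetails.TotalsByModality() reports, for example, how many prompt tokens came from ContentModality.Image versus ContentModality.Text.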

vertexai/src/ModalityTokenCount.cs.meta

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default.

vertexai/src/VertexAIException.cs

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,7 @@
 
 namespace Firebase.VertexAI {
 
-public abstract class VertexAIException : Exception {
+public class VertexAIException : Exception {
   internal VertexAIException(string message) : base(message) { }
 
   internal VertexAIException(string message, Exception exception) : base(message, exception) { }
@@ -62,6 +62,8 @@ internal VertexAIResponseStoppedException(GenerateContentResponse response) :
 
 public class VertexAIRequestTimeoutException : VertexAIException {
   internal VertexAIRequestTimeoutException(string message) : base(message) { }
+
+  internal VertexAIRequestTimeoutException(string message, Exception e) : base(message, e) { }
 }
 
 public class VertexAIInvalidLocationException : VertexAIException {
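
Because VertexAIException is no longer abstract and remains the base of VertexAIRequestTimeoutException, callers can handle the timeout case specifically and fall back to the base type; a sketch, reusing the assumed "model" instance from the earlier examples.

// Illustrative only: handling the timeout path surfaced by this commit.
try {
  CountTokensResponse response = await model.CountTokensAsync("ping");
  UnityEngine.Debug.Log($"Tokens: {response.TotalTokens}");
} catch (VertexAIRequestTimeoutException e) {
  UnityEngine.Debug.LogWarning($"Count tokens timed out: {e.Message}");
} catch (VertexAIException e) {
  UnityEngine.Debug.LogError($"Count tokens failed: {e.Message}");
}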
