Skip to content

Commit 9c71ec9

Browse files
committed
countTokens
1 parent 246052d commit 9c71ec9

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,41 @@ public List<SafetyRating> safetyRatings(UUID id) {
131131
.toList();
132132
}
133133

134+
135+
/**
136+
* Runs a model's tokenizer on input content and returns the token count.
137+
* When using long prompts, it might be useful to count tokens before sending any content to the model.
138+
*
139+
* @param model to be analyzed
140+
* @return the token count
141+
*/
142+
public CompletableFuture<Long> countTokens(GenerativeModel model) {
143+
return execute(() -> {
144+
CompletableFuture<HttpResponse<String>> response = client.sendAsync(
145+
HttpRequest.newBuilder()
146+
.POST(HttpRequest.BodyPublishers.ofString(
147+
jsonParser.toJson(new CountTokenRequest(convert(model)))
148+
))
149+
.uri(URI.create("%s/%s:countTokens?key=%s".formatted(urlPrefix, model.modelName(), apiKey)))
150+
.build(),
151+
HttpResponse.BodyHandlers.ofString()
152+
);
153+
return response
154+
.thenApply(HttpResponse::body)
155+
.thenApply(body -> {
156+
try {
157+
var ctr = jsonParser.fromJson(body, CountTokenResponse.class);
158+
if (ctr.totalTokens() == null) {
159+
throw new RuntimeException("No token field in response");
160+
}
161+
return ctr.totalTokens();
162+
} catch (Exception e) {
163+
throw new RuntimeException("Unexpected body:\n" + body, e);
164+
}
165+
});
166+
});
167+
}
168+
134169
/**
135170
* Generates a response from Gemini API based on the given {@code model}. The response is streamed in chunks of text. The
136171
* stream items are delivered as they arrive.
@@ -210,7 +245,7 @@ public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model
210245
private static GenerateContentRequest convert(GenerativeModel model) {
211246
List<GenerationContent> generationContents = model.contents().stream()
212247
.map(content -> {
213-
// todo change to "switch" over sealed type with jdk 21
248+
// change to "switch" over sealed type with jdk 21
214249
if (content instanceof Content.TextContent textContent) {
215250
return new GenerationContent(
216251
textContent.role(),
@@ -259,7 +294,7 @@ private static GenerateContentRequest convert(GenerativeModel model) {
259294
}
260295
})
261296
.toList();
262-
return new GenerateContentRequest(generationContents, model.safetySettings(), model.generationConfig());
297+
return new GenerateContentRequest(model.modelName(), generationContents, model.safetySettings(), model.generationConfig());
263298
}
264299

265300
private <T> T execute(ThrowingSupplier<T> supplier) {
@@ -344,6 +379,16 @@ public record TypedSafetyRating(
344379

345380
}
346381

382+
private record CountTokenRequest(
383+
GenerateContentRequest generateContentRequest
384+
) {
385+
}
386+
387+
private record CountTokenResponse(
388+
Long totalTokens
389+
) {
390+
}
391+
347392
private record GenerateContentResponse(
348393
UsageMetadata usageMetadata,
349394
List<ResponseCandidate> candidates
@@ -359,6 +404,9 @@ private record ResponseCandidate(
359404
}
360405

361406
private record GenerateContentRequest(
407+
// for some reason, model is required for countToken, but not for the others.
408+
// But it seems to be acceptable for the others, so we just add it to all for now
409+
String model,
362410
List<GenerationContent> contents,
363411
List<SafetySetting> safetySettings,
364412
GenerationConfig generationConfig

gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public static void main(String[] args) throws Exception {
3434
// each method represents an example usage
3535
listModels(genAi);
3636
getModel(genAi);
37+
countTokens(genAi);
3738
generateContent(genAi);
3839
generateContentStream(genAi);
3940
multiChatTurn(genAi);
@@ -43,6 +44,14 @@ public static void main(String[] args) throws Exception {
4344

4445
}
4546

47+
private static void countTokens(GenAi genAi) {
48+
System.out.println("----- count tokens");
49+
var model = createStoryModel();
50+
Long result = genAi.countTokens(model)
51+
.join();
52+
System.out.println("Tokens: " + result);
53+
}
54+
4655
private static void multiChatTurn(GenAi genAi) {
4756
System.out.println("----- multi turn chat");
4857
GenerativeModel chatModel = GenerativeModel.builder()

0 commit comments

Comments
 (0)