@@ -131,6 +131,41 @@ public List<SafetyRating> safetyRatings(UUID id) {
131131 .toList ();
132132 }
133133
134+
135+ /**
136+ * Runs a model's tokenizer on input content and returns the token count.
137+ * When using long prompts, it might be useful to count tokens before sending any content to the model.
138+ *
139+ * @param model to be analyzed
140+ * @return the token count
141+ */
142+ public CompletableFuture <Long > countTokens (GenerativeModel model ) {
143+ return execute (() -> {
144+ CompletableFuture <HttpResponse <String >> response = client .sendAsync (
145+ HttpRequest .newBuilder ()
146+ .POST (HttpRequest .BodyPublishers .ofString (
147+ jsonParser .toJson (new CountTokenRequest (convert (model )))
148+ ))
149+ .uri (URI .create ("%s/%s:countTokens?key=%s" .formatted (urlPrefix , model .modelName (), apiKey )))
150+ .build (),
151+ HttpResponse .BodyHandlers .ofString ()
152+ );
153+ return response
154+ .thenApply (HttpResponse ::body )
155+ .thenApply (body -> {
156+ try {
157+ var ctr = jsonParser .fromJson (body , CountTokenResponse .class );
158+ if (ctr .totalTokens () == null ) {
159+ throw new RuntimeException ("No token field in response" );
160+ }
161+ return ctr .totalTokens ();
162+ } catch (Exception e ) {
163+ throw new RuntimeException ("Unexpected body:\n " + body , e );
164+ }
165+ });
166+ });
167+ }
168+
134169 /**
135170 * Generates a response from Gemini API based on the given {@code model}. The response is streamed in chunks of text. The
136171 * stream items are delivered as they arrive.
@@ -210,7 +245,7 @@ public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model
210245 private static GenerateContentRequest convert (GenerativeModel model ) {
211246 List <GenerationContent > generationContents = model .contents ().stream ()
212247 .map (content -> {
213- // todo change to "switch" over sealed type with jdk 21
                // TODO(jdk21): replace this instanceof chain with an exhaustive switch over the sealed type
214249 if (content instanceof Content .TextContent textContent ) {
215250 return new GenerationContent (
216251 textContent .role (),
@@ -259,7 +294,7 @@ private static GenerateContentRequest convert(GenerativeModel model) {
259294 }
260295 })
261296 .toList ();
262- return new GenerateContentRequest (generationContents , model .safetySettings (), model .generationConfig ());
297+ return new GenerateContentRequest (model . modelName (), generationContents , model .safetySettings (), model .generationConfig ());
263298 }
264299
265300 private <T > T execute (ThrowingSupplier <T > supplier ) {
@@ -344,6 +379,16 @@ public record TypedSafetyRating(
344379
345380 }
346381
    /**
     * Request body for the {@code :countTokens} endpoint: wraps the full
     * generate-content request whose contents are to be tokenized.
     * The component name doubles as the JSON field name when serialized
     * via {@code jsonParser}, so it must not be renamed.
     */
    private record CountTokenRequest(
            GenerateContentRequest generateContentRequest
    ) {
    }
386+
    /**
     * Response body of the {@code :countTokens} endpoint.
     * {@code totalTokens} is boxed ({@code Long}) because the field may be
     * absent from the JSON payload — callers check it for {@code null}.
     */
    private record CountTokenResponse(
            Long totalTokens
    ) {
    }
391+
347392 private record GenerateContentResponse (
348393 UsageMetadata usageMetadata ,
349394 List <ResponseCandidate > candidates
@@ -359,6 +404,9 @@ private record ResponseCandidate(
359404 }
360405
361406 private record GenerateContentRequest (
        // The "model" field is required by the countTokens endpoint but not by the others.
        // Since the other endpoints accept it without complaint, we include it in every request for now.
409+ String model ,
362410 List <GenerationContent > contents ,
363411 List <SafetySetting > safetySettings ,
364412 GenerationConfig generationConfig
0 commit comments