99import java .net .http .HttpRequest ;
1010import java .net .http .HttpResponse ;
1111import java .util .List ;
12+ import java .util .Map ;
13+ import java .util .Optional ;
14+ import java .util .UUID ;
1215import java .util .concurrent .CompletableFuture ;
16+ import java .util .concurrent .ConcurrentHashMap ;
1317import java .util .stream .Stream ;
1418
15- // todo decide if thread-safe or not, once responses are stored
19+ import static java .util .Collections .emptyList ;
20+
1621
1722/**
1823 * Entry point for all interactions with Gemini API.
24+ * Note that some methods store state (e.g. {@link #generateContent(GenerativeModel)} or ${@link #generateContentStream(GenerativeModel)}).
25+ * Call the {@link #close()} method to clean up the state.
26+ * This class is thread safe.
1927 */
20- public class GenAi {
28+ public class GenAi implements AutoCloseable {
2129
2230 private static final String STREAM_LINE_PREFIX = "data: " ;
2331 private static final int STREAM_LINE_PREFIX_LENGTH = STREAM_LINE_PREFIX .length ();
@@ -27,6 +35,7 @@ public class GenAi {
2735 private final String apiKey ;
2836 private final HttpClient client ;
2937 private final JsonParser jsonParser ;
38+ private final Map <UUID , GenerateContentResponse > responseById = new ConcurrentHashMap <>();
3039
3140 public GenAi (
3241 String apiKey ,
@@ -95,19 +104,47 @@ public Model getModel(String model) {
95104 });
96105 }
97106
107+ /**
108+ * Get the usage metadata of a {@link GeneratedContent#id()}.
109+ *
110+ * @param id of the corresponding {@link GeneratedContent}
111+ * @return the corresponding metadata, or an empty optional
112+ */
113+ public Optional <UsageMetadata > usageMetadata (UUID id ) {
114+ return Optional .ofNullable (responseById .get (id ))
115+ .map (GenerateContentResponse ::usageMetadata );
116+ }
117+
118+ /**
119+ * Get the safety ratings of a {@link GeneratedContent#id()}.
120+ *
121+ * @param id of the corresponding {@link GeneratedContent}
122+ * @return the corresponding safety ratings, or an empty optional
123+ */
124+ public List <SafetyRating > safetyRatings (UUID id ) {
125+ GenerateContentResponse response = responseById .get (id );
126+ if (response == null ) {
127+ return emptyList ();
128+ }
129+ return response .candidates ().stream ()
130+ .flatMap (candidate -> candidate .safetyRatings ().stream ())
131+ .toList ();
132+ }
133+
98134 /**
99135 * Generates a response from Gemini API based on the given {@code model}. The response is streamed in chunks of text. The
100136 * stream items are delivered as they arrive.
137+ * Once the call has been completed, metadata and safety ratings can be obtained by calling
138+ * {@link #usageMetadata(UUID)} or {@link #safetyRatings(UUID)}. If those methods are called while the stream is still
139+ * active, the last available statistics are returned.
101140 *
102141 * @param model with the necessary information for Gemini API to generate content
103142 * @return A live stream of the response, as it arrives
104143 * @see #generateContent(GenerativeModel) which returns the whole response at once (asynchronously)
105144 */
106145 public Stream <GeneratedContent > generateContentStream (GenerativeModel model ) {
107- // todo, keep responses in the state.
108- // add up the usageMetadata
109- // store the safety ratings
110146 return execute (() -> {
147+ UUID uuid = UUID .randomUUID ();
111148 HttpRequest request = HttpRequest .newBuilder ()
112149 .POST (HttpRequest .BodyPublishers .ofString (
113150 jsonParser .toJson (convert (model ))
@@ -124,7 +161,9 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
124161 .map (line -> {
125162 try {
126163 var gcr = jsonParser .fromJson (line .substring (STREAM_LINE_PREFIX_LENGTH ), GenerateContentResponse .class );
127- return new GeneratedContent (gcr .candidates ().get (0 ).content ().parts ().get (0 ).text ());
164+ // each element can just replace the previous one
165+ this .responseById .put (uuid , gcr );
166+ return new GeneratedContent (uuid , gcr .candidates ().get (0 ).content ().parts ().get (0 ).text ());
128167 } catch (Exception e ) {
129168 throw new RuntimeException ("Unexpected line:\n " + line , e );
130169 }
@@ -134,6 +173,8 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
134173
135174 /**
136175 * Generates a response from Gemini API based on the given {@code model}.
176+ * Once the call has been completed, metadata and safety ratings can be obtained by calling
177+ * {@link #usageMetadata(UUID)} or {@link #safetyRatings(UUID)}
137178 *
138179 * @param model with the necessary information for Gemini API to generate content
139180 * @return a {@link CompletableFuture} which completes once the response from Gemini API has arrived. The {@link CompletableFuture}
@@ -142,6 +183,7 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
142183 */
143184 public CompletableFuture <GeneratedContent > generateContent (GenerativeModel model ) {
144185 return execute (() -> {
186+ UUID uuid = UUID .randomUUID ();
145187 CompletableFuture <HttpResponse <String >> response = client .sendAsync (
146188 HttpRequest .newBuilder ()
147189 .POST (HttpRequest .BodyPublishers .ofString (
@@ -156,7 +198,8 @@ public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model
156198 .thenApply (body -> {
157199 try {
158200 var gcr = jsonParser .fromJson (body , GenerateContentResponse .class );
159- return new GeneratedContent (gcr .candidates ().get (0 ).content ().parts ().get (0 ).text ());
201+ responseById .put (uuid , gcr );
202+ return new GeneratedContent (uuid , gcr .candidates ().get (0 ).content ().parts ().get (0 ).text ());
160203 } catch (Exception e ) {
161204 throw new RuntimeException ("Unexpected body:\n " + body , e );
162205 }
@@ -230,14 +273,77 @@ private <T> T execute(ThrowingSupplier<T> supplier) {
230273 }
231274 }
232275
276+ /**
277+ * Clears the internal state.
278+ */
279+ @ Override
280+ public void close () {
281+ responseById .clear ();
282+ }
283+
233284 /**
234285 * Content generated by Gemini API.
286+ *
287+ * @param id the id of the request, for subsequent queries regarding metadata of the query
235288 */
236289 public record GeneratedContent (
290+ UUID id ,
237291 String text
238292 ) {
239293 }
240294
295+ /**
296+ * Usage metadata for a given request.
297+ *
298+ * @param promptTokenCount Number of tokens in the prompt.
299+ * @param candidatesTokenCount Total number of tokens for the generated response.
300+ * @param totalTokenCount Total token count for the generation request (prompt + candidates).
301+ */
302+ public record UsageMetadata (
303+ int promptTokenCount ,
304+ int candidatesTokenCount ,
305+ int totalTokenCount
306+ ) {
307+ }
308+
309+ /**
310+ * Safety rating for a given response.
311+ *
312+ * @param category The category for this rating. see {@link swiss.ameri.gemini.api.SafetySetting.HarmCategory}
313+ * @param probability The probability of harm for this content. see {@link swiss.ameri.gemini.api.SafetySetting.HarmProbability}
314+ */
315+ public record SafetyRating (
316+ String category ,
317+ String probability
318+ ) {
319+
320+ /**
321+ * Convert the safety rating to a typed safety rating.
322+ * Might crash if Gemini API changes, and an enum value is missing.
323+ *
324+ * @return the TypedSafetyRating
325+ */
326+ public TypedSafetyRating toTypedSafetyRating () {
327+ return new TypedSafetyRating (
328+ SafetySetting .HarmCategory .valueOf (category ()),
329+ SafetySetting .HarmProbability .valueOf (probability ())
330+ );
331+ }
332+
333+ /**
334+ * Typed values. This is done separately, since enum values might be missing compared to Gemini API
335+ *
336+ * @param harmCategory of this rating
337+ * @param probability of this rating
338+ */
339+ public record TypedSafetyRating (
340+ SafetySetting .HarmCategory harmCategory ,
341+ SafetySetting .HarmProbability probability
342+ ) {
343+ }
344+
345+ }
346+
241347 private record GenerateContentResponse (
242348 UsageMetadata usageMetadata ,
243349 List <ResponseCandidate > candidates
@@ -252,19 +358,6 @@ private record ResponseCandidate(
252358 ) {
253359 }
254360
255- private record SafetyRating (
256- String category ,
257- String probability
258- ) {
259- }
260-
261- private record UsageMetadata (
262- int promptTokenCount ,
263- int candidatesTokenCount ,
264- int totalTokenCount
265- ) {
266- }
267-
268361 private record GenerateContentRequest (
269362 List <GenerationContent > contents ,
270363 List <SafetySetting > safetySettings ,
0 commit comments