Skip to content

Commit 246052d

Browse files
Merge pull request #1 from michael-ameri/feature/track-usage
feature/track-usage
2 parents 8b933d4 + dcdabc0 commit 246052d

File tree

3 files changed

+169
-34
lines changed

3 files changed

+169
-34
lines changed

gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,23 @@
99
import java.net.http.HttpRequest;
1010
import java.net.http.HttpResponse;
1111
import java.util.List;
12+
import java.util.Map;
13+
import java.util.Optional;
14+
import java.util.UUID;
1215
import java.util.concurrent.CompletableFuture;
16+
import java.util.concurrent.ConcurrentHashMap;
1317
import java.util.stream.Stream;
1418

15-
// todo decide if thread-safe or not, once responses are stored
19+
import static java.util.Collections.emptyList;
20+
1621

1722
/**
1823
* Entry point for all interactions with Gemini API.
24+
* Note that some methods store state (e.g. {@link #generateContent(GenerativeModel)} or ${@link #generateContentStream(GenerativeModel)}).
25+
* Call the {@link #close()} method to clean up the state.
26+
* This class is thread safe.
1927
*/
20-
public class GenAi {
28+
public class GenAi implements AutoCloseable {
2129

2230
private static final String STREAM_LINE_PREFIX = "data: ";
2331
private static final int STREAM_LINE_PREFIX_LENGTH = STREAM_LINE_PREFIX.length();
@@ -27,6 +35,7 @@ public class GenAi {
2735
private final String apiKey;
2836
private final HttpClient client;
2937
private final JsonParser jsonParser;
38+
private final Map<UUID, GenerateContentResponse> responseById = new ConcurrentHashMap<>();
3039

3140
public GenAi(
3241
String apiKey,
@@ -95,19 +104,47 @@ public Model getModel(String model) {
95104
});
96105
}
97106

107+
/**
108+
* Get the usage metadata of a {@link GeneratedContent#id()}.
109+
*
110+
* @param id of the corresponding {@link GeneratedContent}
111+
* @return the corresponding metadata, or an empty optional
112+
*/
113+
public Optional<UsageMetadata> usageMetadata(UUID id) {
114+
return Optional.ofNullable(responseById.get(id))
115+
.map(GenerateContentResponse::usageMetadata);
116+
}
117+
118+
/**
119+
* Get the safety ratings of a {@link GeneratedContent#id()}.
120+
*
121+
* @param id of the corresponding {@link GeneratedContent}
122+
* @return the corresponding safety ratings, or an empty optional
123+
*/
124+
public List<SafetyRating> safetyRatings(UUID id) {
125+
GenerateContentResponse response = responseById.get(id);
126+
if (response == null) {
127+
return emptyList();
128+
}
129+
return response.candidates().stream()
130+
.flatMap(candidate -> candidate.safetyRatings().stream())
131+
.toList();
132+
}
133+
98134
/**
99135
* Generates a response from Gemini API based on the given {@code model}. The response is streamed in chunks of text. The
100136
* stream items are delivered as they arrive.
137+
* Once the call has been completed, metadata and safety ratings can be obtained by calling
138+
* {@link #usageMetadata(UUID)} or {@link #safetyRatings(UUID)}. If those methods are called while the stream is still
139+
* active, the last available statistics are returned.
101140
*
102141
* @param model with the necessary information for Gemini API to generate content
103142
* @return A live stream of the response, as it arrives
104143
* @see #generateContent(GenerativeModel) which returns the whole response at once (asynchronously)
105144
*/
106145
public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
107-
// todo, keep responses in the state.
108-
// add up the usageMetadata
109-
// store the safety ratings
110146
return execute(() -> {
147+
UUID uuid = UUID.randomUUID();
111148
HttpRequest request = HttpRequest.newBuilder()
112149
.POST(HttpRequest.BodyPublishers.ofString(
113150
jsonParser.toJson(convert(model))
@@ -124,7 +161,9 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
124161
.map(line -> {
125162
try {
126163
var gcr = jsonParser.fromJson(line.substring(STREAM_LINE_PREFIX_LENGTH), GenerateContentResponse.class);
127-
return new GeneratedContent(gcr.candidates().get(0).content().parts().get(0).text());
164+
// each element can just replace the previous one
165+
this.responseById.put(uuid, gcr);
166+
return new GeneratedContent(uuid, gcr.candidates().get(0).content().parts().get(0).text());
128167
} catch (Exception e) {
129168
throw new RuntimeException("Unexpected line:\n" + line, e);
130169
}
@@ -134,6 +173,8 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
134173

135174
/**
136175
* Generates a response from Gemini API based on the given {@code model}.
176+
* Once the call has been completed, metadata and safety ratings can be obtained by calling
177+
* {@link #usageMetadata(UUID)} or {@link #safetyRatings(UUID)}
137178
*
138179
* @param model with the necessary information for Gemini API to generate content
139180
* @return a {@link CompletableFuture} which completes once the response from Gemini API has arrived. The {@link CompletableFuture}
@@ -142,6 +183,7 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
142183
*/
143184
public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model) {
144185
return execute(() -> {
186+
UUID uuid = UUID.randomUUID();
145187
CompletableFuture<HttpResponse<String>> response = client.sendAsync(
146188
HttpRequest.newBuilder()
147189
.POST(HttpRequest.BodyPublishers.ofString(
@@ -156,7 +198,8 @@ public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model
156198
.thenApply(body -> {
157199
try {
158200
var gcr = jsonParser.fromJson(body, GenerateContentResponse.class);
159-
return new GeneratedContent(gcr.candidates().get(0).content().parts().get(0).text());
201+
responseById.put(uuid, gcr);
202+
return new GeneratedContent(uuid, gcr.candidates().get(0).content().parts().get(0).text());
160203
} catch (Exception e) {
161204
throw new RuntimeException("Unexpected body:\n" + body, e);
162205
}
@@ -230,14 +273,77 @@ private <T> T execute(ThrowingSupplier<T> supplier) {
230273
}
231274
}
232275

276+
/**
277+
* Clears the internal state.
278+
*/
279+
@Override
280+
public void close() {
281+
responseById.clear();
282+
}
283+
233284
/**
234285
* Content generated by Gemini API.
286+
*
287+
* @param id the id of the request, for subsequent queries regarding metadata of the query
235288
*/
236289
public record GeneratedContent(
290+
UUID id,
237291
String text
238292
) {
239293
}
240294

295+
/**
296+
* Usage metadata for a given request.
297+
*
298+
* @param promptTokenCount Number of tokens in the prompt.
299+
* @param candidatesTokenCount Total number of tokens for the generated response.
300+
* @param totalTokenCount Total token count for the generation request (prompt + candidates).
301+
*/
302+
public record UsageMetadata(
303+
int promptTokenCount,
304+
int candidatesTokenCount,
305+
int totalTokenCount
306+
) {
307+
}
308+
309+
/**
310+
* Safety rating for a given response.
311+
*
312+
* @param category The category for this rating. see {@link swiss.ameri.gemini.api.SafetySetting.HarmCategory}
313+
* @param probability The probability of harm for this content. see {@link swiss.ameri.gemini.api.SafetySetting.HarmProbability}
314+
*/
315+
public record SafetyRating(
316+
String category,
317+
String probability
318+
) {
319+
320+
/**
321+
* Convert the safety rating to a typed safety rating.
322+
* Might crash if Gemini API changes, and an enum value is missing.
323+
*
324+
* @return the TypedSafetyRating
325+
*/
326+
public TypedSafetyRating toTypedSafetyRating() {
327+
return new TypedSafetyRating(
328+
SafetySetting.HarmCategory.valueOf(category()),
329+
SafetySetting.HarmProbability.valueOf(probability())
330+
);
331+
}
332+
333+
/**
334+
* Typed values. This is done separately, since enum values might be missing compared to Gemini API
335+
*
336+
* @param harmCategory of this rating
337+
* @param probability of this rating
338+
*/
339+
public record TypedSafetyRating(
340+
SafetySetting.HarmCategory harmCategory,
341+
SafetySetting.HarmProbability probability
342+
) {
343+
}
344+
345+
}
346+
241347
private record GenerateContentResponse(
242348
UsageMetadata usageMetadata,
243349
List<ResponseCandidate> candidates
@@ -252,19 +358,6 @@ private record ResponseCandidate(
252358
) {
253359
}
254360

255-
private record SafetyRating(
256-
String category,
257-
String probability
258-
) {
259-
}
260-
261-
private record UsageMetadata(
262-
int promptTokenCount,
263-
int candidatesTokenCount,
264-
int totalTokenCount
265-
) {
266-
}
267-
268361
private record GenerateContentRequest(
269362
List<GenerationContent> contents,
270363
List<SafetySetting> safetySettings,

gemini-api/src/main/java/swiss/ameri/gemini/api/SafetySetting.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,37 @@ public enum HarmBlockThreshold {
9999
BLOCK_NONE
100100
}
101101

102+
/**
103+
* The probability that a piece of content is harmful.
104+
* The classification system gives the probability of the content being unsafe.
105+
* This does not indicate the severity of harm for a piece of content.
106+
*/
107+
public enum HarmProbability {
108+
109+
/**
110+
* Probability is unspecified.
111+
*/
112+
HARM_PROBABILITY_UNSPECIFIED,
113+
114+
/**
115+
* Content has a negligible chance of being unsafe.
116+
*/
117+
NEGLIGIBLE,
118+
119+
/**
120+
* Content has a low chance of being unsafe.
121+
*/
122+
LOW,
123+
124+
/**
125+
* Content has a medium chance of being unsafe.
126+
*/
127+
MEDIUM,
128+
129+
/**
130+
* Content has a high chance of being unsafe.
131+
*/
132+
HIGH
133+
}
134+
102135
}

gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,17 @@ public static void main(String[] args) throws Exception {
3030
JsonParser parser = new GsonJsonParser();
3131
String apiKey = args[0];
3232

33-
GenAi genAi = new GenAi(
34-
apiKey,
35-
parser
36-
);
33+
try (var genAi = new GenAi(apiKey, parser)) {
34+
// each method represents an example usage
35+
listModels(genAi);
36+
getModel(genAi);
37+
generateContent(genAi);
38+
generateContentStream(genAi);
39+
multiChatTurn(genAi);
40+
textAndImage(genAi);
41+
}
42+
3743

38-
// each method represents an example usage
39-
listModels(genAi);
40-
getModel(genAi);
41-
generateContent(genAi);
42-
generateContentStream(genAi);
43-
multiChatTurn(genAi);
44-
textAndImage(genAi);
4544
}
4645

4746
private static void multiChatTurn(GenAi genAi) {
@@ -66,17 +65,27 @@ private static void multiChatTurn(GenAi genAi) {
6665
}
6766

6867
private static void generateContentStream(GenAi genAi) {
69-
System.out.println("----- Generate content (streaming)");
68+
System.out.println("----- Generate content (streaming) -- with usage meta data");
7069
var model = createStoryModel();
7170
genAi.generateContentStream(model)
72-
.forEach(System.out::println);
71+
.forEach(x -> {
72+
System.out.println(x);
73+
// note that the usage metadata is updated as it arrives
74+
System.out.println(genAi.usageMetadata(x.id()));
75+
System.out.println(genAi.safetyRatings(x.id()));
76+
});
7377
}
7478

7579
private static void generateContent(GenAi genAi) throws InterruptedException, ExecutionException, TimeoutException {
7680
var model = createStoryModel();
7781
System.out.println("----- Generate content (blocking)");
7882
genAi.generateContent(model)
79-
.thenAccept(System.out::println)
83+
.thenAccept(gcr -> {
84+
System.out.println(gcr);
85+
System.out.println("----- Generate content (blocking) usage meta data & safety ratings");
86+
System.out.println(genAi.usageMetadata(gcr.id()));
87+
System.out.println(genAi.safetyRatings(gcr.id()).stream().map(GenAi.SafetyRating::toTypedSafetyRating).toList());
88+
})
8089
.get(20, TimeUnit.SECONDS);
8190
}
8291

0 commit comments

Comments
 (0)