Skip to content

Commit c0ea3ac

Browse files
committed
embedContents
1 parent 9c71ec9 commit c0ea3ac

File tree

3 files changed

+181
-2
lines changed

3 files changed

+181
-2
lines changed

gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,78 @@ public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model
242242
});
243243
}
244244

245+
/**
246+
* Embedding is a technique used to represent information as a list of floating point numbers in an array.
247+
* With Gemini, you can represent text (words, sentences, and blocks of text) in a vectorized form,
248+
* making it easier to compare and contrast embeddings.
249+
* For example, two texts that share a similar subject or sentiment should have similar embeddings,
250+
* which can be identified through mathematical comparison techniques such as cosine similarity.
251+
*
252+
* @param model to use. Currently, only {@link ModelVariant#TEXT_EMBEDDING_004} is allowed.
253+
* @param taskType Optional. Optional task type for which the embeddings will be used. For possible values, see {@link TaskType}
254+
* @param title Optional. An optional title for the text. Only applicable when TaskType is RETRIEVAL_DOCUMENT.
255+
* Note: Specifying a title for RETRIEVAL_DOCUMENT provides better quality embeddings for retrieval.
256+
* @param outputDimensionality Optional. Optional reduced dimension for the output embedding.
257+
* If set, excessive values in the output embedding are truncated from the end.
258+
* Supported by newer models since 2024, and the earlier model (models/embedding-001) cannot specify this value.
259+
* @return List of values
260+
* @apiNote Only {@link swiss.ameri.gemini.api.Content.TextContent} are allowed.
261+
*/
262+
public CompletableFuture<List<ContentEmbedding>> embedContents(
263+
GenerativeModel model,
264+
String taskType,
265+
String title,
266+
Long outputDimensionality
267+
) {
268+
return execute(() -> {
269+
270+
var requests = convertGenerationContents(model)
271+
.stream()
272+
.map(generationContent -> new EmbedContentRequest(
273+
model.modelName(),
274+
generationContent,
275+
taskType,
276+
title,
277+
outputDimensionality
278+
))
279+
.toList();
280+
281+
var request = new BatchEmbedContentRequest(requests);
282+
283+
CompletableFuture<HttpResponse<String>> response = client.sendAsync(
284+
HttpRequest.newBuilder()
285+
.POST(HttpRequest.BodyPublishers.ofString(
286+
jsonParser.toJson(request)
287+
))
288+
.uri(URI.create("%s/%s:batchEmbedContents?key=%s".formatted(urlPrefix, model.modelName(), apiKey)))
289+
.build(),
290+
HttpResponse.BodyHandlers.ofString()
291+
);
292+
return response
293+
.thenApply(HttpResponse::body)
294+
.thenApply(body -> {
295+
try {
296+
BatchEmbedContentResponse becr = jsonParser.fromJson(body, BatchEmbedContentResponse.class);
297+
if (becr.embeddings() == null) {
298+
throw new RuntimeException();
299+
}
300+
return becr
301+
.embeddings();
302+
} catch (Exception e) {
303+
throw new RuntimeException("Unexpected body:\n" + body, e);
304+
}
305+
});
306+
307+
});
308+
}
309+
245310
private static GenerateContentRequest convert(GenerativeModel model) {
246-
List<GenerationContent> generationContents = model.contents().stream()
311+
List<GenerationContent> generationContents = convertGenerationContents(model);
312+
return new GenerateContentRequest(model.modelName(), generationContents, model.safetySettings(), model.generationConfig());
313+
}
314+
315+
private static List<GenerationContent> convertGenerationContents(GenerativeModel model) {
316+
return model.contents().stream()
247317
.map(content -> {
248318
// change to "switch" over sealed type with jdk 21
249319
if (content instanceof Content.TextContent textContent) {
@@ -294,7 +364,6 @@ private static GenerateContentRequest convert(GenerativeModel model) {
294364
}
295365
})
296366
.toList();
297-
return new GenerateContentRequest(model.modelName(), generationContents, model.safetySettings(), model.generationConfig());
298367
}
299368

300369
private <T> T execute(ThrowingSupplier<T> supplier) {
@@ -379,6 +448,35 @@ public record TypedSafetyRating(
379448

380449
}
381450

451+
/**
452+
* A list of floats representing an embedding.
453+
*
454+
* @param values A list of floats representing an embedding.
455+
*/
456+
public record ContentEmbedding(
457+
List<Double> values
458+
) {
459+
}
460+
461+
private record BatchEmbedContentRequest(
462+
List<EmbedContentRequest> requests
463+
) {
464+
}
465+
466+
private record EmbedContentRequest(
467+
String model,
468+
GenerationContent content,
469+
String taskType,
470+
String title,
471+
Long outputDimensionality
472+
) {
473+
}
474+
475+
private record BatchEmbedContentResponse(
476+
List<ContentEmbedding> embeddings
477+
) {
478+
}
479+
382480
private record CountTokenRequest(
383481
GenerateContentRequest generateContentRequest
384482
) {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package swiss.ameri.gemini.api;
2+
3+
/**
4+
* Type of task for which the embedding will be used.
5+
*/
6+
public enum TaskType {
7+
/**
8+
* Unset value, which will default to one of the other enum values.
9+
*/
10+
TASK_TYPE_UNSPECIFIED,
11+
12+
/**
13+
* Specifies the given text is a query in a search/retrieval setting.
14+
*/
15+
RETRIEVAL_QUERY,
16+
17+
/**
18+
* Specifies the given text is a document from the corpus being searched.
19+
*/
20+
RETRIEVAL_DOCUMENT,
21+
22+
/**
23+
* Specifies the given text will be used for Semantic Textual Similarity (STS).
24+
*/
25+
SEMANTIC_SIMILARITY,
26+
27+
/**
28+
* Specifies that the given text will be classified.
29+
*/
30+
CLASSIFICATION,
31+
32+
/**
33+
* Specifies that the embeddings will be used for clustering.
34+
*/
35+
CLUSTERING,
36+
37+
/**
38+
* Specifies that the given text will be used for question answering.
39+
*/
40+
QUESTION_ANSWERING,
41+
42+
/**
43+
* Specifies that the given text will be used for fact verification.
44+
*/
45+
FACT_VERIFICATION
46+
}

gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.io.IOException;
88
import java.io.InputStream;
99
import java.util.Base64;
10+
import java.util.List;
1011
import java.util.concurrent.ExecutionException;
1112
import java.util.concurrent.TimeUnit;
1213
import java.util.concurrent.TimeoutException;
@@ -39,11 +40,45 @@ public static void main(String[] args) throws Exception {
3940
generateContentStream(genAi);
4041
multiChatTurn(genAi);
4142
textAndImage(genAi);
43+
embedContents(genAi);
4244
}
4345

4446

4547
}
4648

49+
private static void embedContents(GenAi genAi) {
50+
System.out.println("----- embed contents");
51+
var model = GenerativeModel.builder()
52+
.modelName(ModelVariant.TEXT_EMBEDDING_004)
53+
.addContent(Content.textContent(
54+
Content.Role.USER,
55+
"Write a 50 word story about a magic backpack."
56+
))
57+
.addContent(Content.textContent(
58+
Content.Role.MODEL,
59+
"bla bla bla bla"
60+
))
61+
.addSafetySetting(SafetySetting.of(
62+
SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
63+
SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH
64+
))
65+
.generationConfig(new GenerationConfig(
66+
null,
67+
null,
68+
null,
69+
null,
70+
null,
71+
null,
72+
null
73+
))
74+
.build();
75+
76+
List<GenAi.ContentEmbedding> embeddings = genAi.embedContents(model, null, null, null).join();
77+
System.out.println("Embedding count: " + embeddings.size());
78+
System.out.println("Values per embedding: " + embeddings.stream().map(GenAi.ContentEmbedding::values).map(List::size).toList());
79+
80+
}
81+
4782
private static void countTokens(GenAi genAi) {
4883
System.out.println("----- count tokens");
4984
var model = createStoryModel();

0 commit comments

Comments
 (0)