@@ -242,8 +242,78 @@ public CompletableFuture<GeneratedContent> generateContent(GenerativeModel model
242242 });
243243 }
244244
245+ /**
246+ * Embedding is a technique used to represent information as a list of floating point numbers in an array.
247+ * With Gemini, you can represent text (words, sentences, and blocks of text) in a vectorized form,
248+ * making it easier to compare and contrast embeddings.
249+ * For example, two texts that share a similar subject or sentiment should have similar embeddings,
250+ * which can be identified through mathematical comparison techniques such as cosine similarity.
251+ *
252+ * @param model to use. Currently, only {@link ModelVariant#TEXT_EMBEDDING_004} is allowed.
253+ * @param taskType Optional. Optional task type for which the embeddings will be used. For possible values, see {@link TaskType}
254+ * @param title Optional. An optional title for the text. Only applicable when TaskType is RETRIEVAL_DOCUMENT.
255+ * Note: Specifying a title for RETRIEVAL_DOCUMENT provides better quality embeddings for retrieval.
256+ * @param outputDimensionality Optional. Optional reduced dimension for the output embedding.
257+ * If set, excessive values in the output embedding are truncated from the end.
258+ * Supported by newer models since 2024, and the earlier model (models/embedding-001) cannot specify this value.
259+ * @return List of values
260+ * @apiNote Only {@link swiss.ameri.gemini.api.Content.TextContent} are allowed.
261+ */
262+ public CompletableFuture <List <ContentEmbedding >> embedContents (
263+ GenerativeModel model ,
264+ String taskType ,
265+ String title ,
266+ Long outputDimensionality
267+ ) {
268+ return execute (() -> {
269+
270+ var requests = convertGenerationContents (model )
271+ .stream ()
272+ .map (generationContent -> new EmbedContentRequest (
273+ model .modelName (),
274+ generationContent ,
275+ taskType ,
276+ title ,
277+ outputDimensionality
278+ ))
279+ .toList ();
280+
281+ var request = new BatchEmbedContentRequest (requests );
282+
283+ CompletableFuture <HttpResponse <String >> response = client .sendAsync (
284+ HttpRequest .newBuilder ()
285+ .POST (HttpRequest .BodyPublishers .ofString (
286+ jsonParser .toJson (request )
287+ ))
288+ .uri (URI .create ("%s/%s:batchEmbedContents?key=%s" .formatted (urlPrefix , model .modelName (), apiKey )))
289+ .build (),
290+ HttpResponse .BodyHandlers .ofString ()
291+ );
292+ return response
293+ .thenApply (HttpResponse ::body )
294+ .thenApply (body -> {
295+ try {
296+ BatchEmbedContentResponse becr = jsonParser .fromJson (body , BatchEmbedContentResponse .class );
297+ if (becr .embeddings () == null ) {
298+ throw new RuntimeException ();
299+ }
300+ return becr
301+ .embeddings ();
302+ } catch (Exception e ) {
303+ throw new RuntimeException ("Unexpected body:\n " + body , e );
304+ }
305+ });
306+
307+ });
308+ }
309+
245310 private static GenerateContentRequest convert (GenerativeModel model ) {
246- List <GenerationContent > generationContents = model .contents ().stream ()
311+ List <GenerationContent > generationContents = convertGenerationContents (model );
312+ return new GenerateContentRequest (model .modelName (), generationContents , model .safetySettings (), model .generationConfig ());
313+ }
314+
315+ private static List <GenerationContent > convertGenerationContents (GenerativeModel model ) {
316+ return model .contents ().stream ()
247317 .map (content -> {
248318 // change to "switch" over sealed type with jdk 21
249319 if (content instanceof Content .TextContent textContent ) {
@@ -294,7 +364,6 @@ private static GenerateContentRequest convert(GenerativeModel model) {
294364 }
295365 })
296366 .toList ();
297- return new GenerateContentRequest (model .modelName (), generationContents , model .safetySettings (), model .generationConfig ());
298367 }
299368
300369 private <T > T execute (ThrowingSupplier <T > supplier ) {
@@ -379,6 +448,35 @@ public record TypedSafetyRating(
379448
380449 }
381450
/**
 * A single embedding, i.e. the numeric vector contained in an embedding response.
 *
 * @param values the components of the embedding vector
 */
public record ContentEmbedding(List<Double> values) {
}
460+
461+ private record BatchEmbedContentRequest (
462+ List <EmbedContentRequest > requests
463+ ) {
464+ }
465+
466+ private record EmbedContentRequest (
467+ String model ,
468+ GenerationContent content ,
469+ String taskType ,
470+ String title ,
471+ Long outputDimensionality
472+ ) {
473+ }
474+
475+ private record BatchEmbedContentResponse (
476+ List <ContentEmbedding > embeddings
477+ ) {
478+ }
479+
382480 private record CountTokenRequest (
383481 GenerateContentRequest generateContentRequest
384482 ) {
0 commit comments