@@ -41,14 +41,14 @@ public static void main(String[] args) throws IOException {
4141 // https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
4242 String endpoint = "us-central1-aiplatform.googleapis.com:443" ;
4343 String project = "YOUR_PROJECT_ID" ;
44- String model = "text -embedding-005 " ;
44+ String model = "gemini -embedding-001 " ;
4545 predictTextEmbeddings (
4646 endpoint ,
4747 project ,
4848 model ,
4949 List .of ("banana bread?" , "banana muffins?" ),
5050 "QUESTION_ANSWERING" ,
51- OptionalInt .of (256 ));
51+ OptionalInt .of (3072 ));
5252 }
5353
5454 // Gets text embeddings from a pretrained, foundational model.
@@ -67,37 +67,40 @@ public static List<List<Float>> predictTextEmbeddings(
6767 EndpointName endpointName =
6868 EndpointName .ofProjectLocationPublisherModelName (project , location , "google" , model );
6969
70+ List <List <Float >> floats = new ArrayList <>();
7071 // You can use this prediction service client for multiple requests.
7172 try (PredictionServiceClient client = PredictionServiceClient .create (settings )) {
72- PredictRequest .Builder request =
73- PredictRequest .newBuilder ().setEndpoint (endpointName .toString ());
74- if (outputDimensionality .isPresent ()) {
75- request .setParameters (
76- Value .newBuilder ()
77- .setStructValue (
78- Struct .newBuilder ()
79- .putFields ("outputDimensionality" , valueOf (outputDimensionality .getAsInt ()))
80- .build ()));
81- }
73+ // gemini-embedding-001 takes one input at a time.
8274 for (int i = 0 ; i < texts .size (); i ++) {
75+ PredictRequest .Builder request =
76+ PredictRequest .newBuilder ().setEndpoint (endpointName .toString ());
77+ if (outputDimensionality .isPresent ()) {
78+ request .setParameters (
79+ Value .newBuilder ()
80+ .setStructValue (
81+ Struct .newBuilder ()
82+ .putFields (
83+ "outputDimensionality" , valueOf (outputDimensionality .getAsInt ()))
84+ .build ()));
85+ }
8386 request .addInstances (
8487 Value .newBuilder ()
8588 .setStructValue (
8689 Struct .newBuilder ()
8790 .putFields ("content" , valueOf (texts .get (i )))
8891 .putFields ("task_type" , valueOf (task ))
8992 .build ()));
90- }
91- PredictResponse response = client . predict ( request . build ());
92- List < List < Float >> floats = new ArrayList <>();
93- for ( Value prediction : response . getPredictionsList ()) {
94- Value embeddings = prediction .getStructValue ().getFieldsOrThrow ("embeddings " );
95- Value values = embeddings . getStructValue (). getFieldsOrThrow ( "values" );
96- floats . add (
97- values . getListValue (). getValuesList (). stream ( )
98- .map (Value :: getNumberValue )
99- . map ( Double :: floatValue )
100- . collect ( toList ()));
93+ PredictResponse response = client . predict ( request . build ());
94+
95+ for ( Value prediction : response . getPredictionsList ()) {
96+ Value embeddings = prediction . getStructValue (). getFieldsOrThrow ( "embeddings" );
97+ Value values = embeddings .getStructValue ().getFieldsOrThrow ("values " );
98+ floats . add (
99+ values . getListValue (). getValuesList (). stream ()
100+ . map ( Value :: getNumberValue )
101+ .map (Double :: floatValue )
102+ . collect ( toList ()));
103+ }
101104 }
102105 return floats ;
103106 }
0 commit comments