Skip to content

Commit f4beed6

Browse files
sobychackonamsoo2
authored andcommitted
spring-projectsGH-2168: Fix task type property name in Vertex AI embedding requests
Fixes: spring-projects#2168 - Change property name from 'taskType' to 'task_type' in VertexAiEmbeddingUtils to match Google API expectations - Add integration tests to verify task type behavior matches Google SDK - Add missing auto truncate option copying in VertexAiTextEmbeddingOptions Signed-off-by: Soby Chacko <[email protected]> Signed-off-by: minsoo.nam <[email protected]>
1 parent f37dd6b commit f4beed6

File tree

3 files changed

+123
-2
lines changed

3 files changed

+123
-2
lines changed

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ public Struct build() {
140140
Struct.Builder textBuilder = Struct.newBuilder();
141141
textBuilder.putFields("content", valueOf(this.content));
142142
if (StringUtils.hasText(this.taskType)) {
143-
textBuilder.putFields("taskType", valueOf(this.taskType));
143+
textBuilder.putFields("task_type", valueOf(this.taskType));
144144
}
145145
if (StringUtils.hasText(this.title)) {
146146
textBuilder.putFields("title", valueOf(this.title));

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ public Builder from(VertexAiTextEmbeddingOptions fromOptions) {
187187
if (fromOptions.getTaskType() != null) {
188188
this.options.setTaskType(fromOptions.getTaskType());
189189
}
190+
if (fromOptions.getAutoTruncate() != null) {
191+
this.options.setAutoTruncate(fromOptions.getAutoTruncate());
192+
}
190193
if (StringUtils.hasText(fromOptions.getTitle())) {
191194
this.options.setTitle(fromOptions.getTitle());
192195
}

models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,13 @@
1818

1919
import java.util.List;
2020

21+
import com.google.cloud.aiplatform.v1.EndpointName;
22+
import com.google.cloud.aiplatform.v1.PredictRequest;
23+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
25+
import com.google.protobuf.Struct;
26+
import com.google.protobuf.Value;
27+
import org.junit.jupiter.api.Test;
2128
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2229
import org.junit.jupiter.params.ParameterizedTest;
2330
import org.junit.jupiter.params.provider.ValueSource;
@@ -30,6 +37,7 @@
3037
import org.springframework.boot.test.context.SpringBootTest;
3138
import org.springframework.context.annotation.Bean;
3239

40+
import static java.util.stream.Collectors.toList;
3341
import static org.assertj.core.api.Assertions.assertThat;
3442

3543
@SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class)
@@ -65,6 +73,116 @@ void defaultEmbedding(String modelName) {
6573
assertThat(this.embeddingModel.dimensions()).isEqualTo(768);
6674
}
6775

76+
// Fixing https://github.com/spring-projects/spring-ai/issues/2168
77+
@Test
78+
void testTaskTypeProperty() {
79+
// Use text-embedding-005 model
80+
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
81+
.model("text-embedding-005")
82+
.taskType(VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT)
83+
.build();
84+
85+
String text = "Test text for embedding";
86+
87+
// Generate embedding using Spring AI with RETRIEVAL_DOCUMENT task type
88+
EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));
89+
90+
assertThat(embeddingResponse.getResults()).hasSize(1);
91+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotNull();
92+
93+
// Get the embedding result
94+
float[] springAiEmbedding = embeddingResponse.getResults().get(0).getOutput();
95+
96+
// Now generate the same embedding using Google SDK directly with
97+
// RETRIEVAL_DOCUMENT
98+
float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT");
99+
100+
// Also generate embedding using Google SDK with RETRIEVAL_QUERY (which is the
101+
// default)
102+
float[] googleSdkQueryEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_QUERY");
103+
104+
// Spring AI embedding should match with what gets generated by Google SDK with
105+
// RETRIEVAL_DOCUMENT task type.
106+
assertThat(springAiEmbedding)
107+
.as("Spring AI embedding with RETRIEVAL_DOCUMENT should match Google SDK RETRIEVAL_DOCUMENT embedding")
108+
.isEqualTo(googleSdkDocumentEmbedding);
109+
110+
// Spring AI embedding which uses RETRIEVAL_DOCUMENT task_type should not match
111+
// with what gets generated by
112+
// Google SDK with RETRIEVAL_QUERY task type.
113+
assertThat(springAiEmbedding)
114+
.as("Spring AI embedding with RETRIEVAL_DOCUMENT should NOT match Google SDK RETRIEVAL_QUERY embedding")
115+
.isNotEqualTo(googleSdkQueryEmbedding);
116+
}
117+
118+
// Fixing https://github.com/spring-projects/spring-ai/issues/2168
119+
@Test
120+
void testDefaultTaskTypeBehavior() {
121+
// Test default behavior without explicitly setting task type
122+
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
123+
.model("text-embedding-005")
124+
.build();
125+
126+
String text = "Test text for default embedding";
127+
128+
EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));
129+
130+
assertThat(embeddingResponse.getResults()).hasSize(1);
131+
132+
float[] springAiDefaultEmbedding = embeddingResponse.getResults().get(0).getOutput();
133+
134+
// According to documentation, default should be RETRIEVAL_DOCUMENT
135+
float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT");
136+
137+
assertThat(springAiDefaultEmbedding)
138+
.as("Default Spring AI embedding should match Google SDK RETRIEVAL_DOCUMENT embedding")
139+
.isEqualTo(googleSdkDocumentEmbedding);
140+
}
141+
142+
private float[] getEmbeddingUsingGoogleSdk(String text, String taskType) {
143+
try {
144+
String endpoint = String.format("%s-aiplatform.googleapis.com:443",
145+
System.getenv("VERTEX_AI_GEMINI_LOCATION"));
146+
String project = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");
147+
148+
PredictionServiceSettings settings = PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
149+
150+
EndpointName endpointName = EndpointName.ofProjectLocationPublisherModelName(project,
151+
System.getenv("VERTEX_AI_GEMINI_LOCATION"), "google", "text-embedding-005");
152+
153+
try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
154+
PredictRequest.Builder request = PredictRequest.newBuilder().setEndpoint(endpointName.toString());
155+
156+
request.addInstances(Value.newBuilder()
157+
.setStructValue(Struct.newBuilder()
158+
.putFields("content", Value.newBuilder().setStringValue(text).build())
159+
.putFields("task_type", Value.newBuilder().setStringValue(taskType).build())
160+
.build())
161+
.build());
162+
163+
var prediction = client.predict(request.build()).getPredictionsList().get(0);
164+
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
165+
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
166+
167+
List<Float> floatList = values.getListValue()
168+
.getValuesList()
169+
.stream()
170+
.map(Value::getNumberValue)
171+
.map(Double::floatValue)
172+
.collect(toList());
173+
174+
float[] floatArray = new float[floatList.size()];
175+
for (int i = 0; i < floatList.size(); i++) {
176+
floatArray[i] = floatList.get(i);
177+
}
178+
return floatArray;
179+
}
180+
}
181+
catch (Exception e) {
182+
throw new RuntimeException("Failed to get embedding from Google SDK", e);
183+
}
184+
}
185+
68186
@SpringBootConfiguration
69187
static class Config {
70188

0 commit comments

Comments
 (0)