Skip to content

Commit 9552422

Browse files
authored
Fixing bug setting index when parsing Google Vertex AI results (#117287) (#117358)
* Using record ID as index value when parsing Google Vertex AI rerank results * Update docs/changelog/117287.yaml * PR feedback
1 parent d95c003 commit 9552422

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

docs/changelog/117287.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117287
2+
summary: Fixing bug setting index when parsing Google Vertex AI results
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
public class GoogleVertexAiRerankResponseEntity {
3131

3232
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Google Vertex AI rerank response";
33+
private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Google Vertex AI rerank "
34+
+ "response but received [%s]";
3335

3436
/**
3537
* Parses the Google Vertex AI rerank response.
@@ -109,14 +111,27 @@ private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser)
109111
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
110112
}
111113

112-
return new RankedDocsResults.RankedDoc(index, parsedRankedDoc.score, parsedRankedDoc.content);
114+
if (parsedRankedDoc.id == null) {
115+
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.ID.getPreferredName()));
116+
}
117+
118+
try {
119+
return new RankedDocsResults.RankedDoc(
120+
Integer.parseInt(parsedRankedDoc.id),
121+
parsedRankedDoc.score,
122+
parsedRankedDoc.content
123+
);
124+
} catch (NumberFormatException e) {
125+
throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id));
126+
}
113127
});
114128
}
115129

116-
private record RankedDoc(@Nullable Float score, @Nullable String content) {
130+
private record RankedDoc(@Nullable Float score, @Nullable String content, @Nullable String id) {
117131

118132
private static final ParseField CONTENT = new ParseField("content");
119133
private static final ParseField SCORE = new ParseField("score");
134+
private static final ParseField ID = new ParseField("id");
120135
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
121136
"google_vertex_ai_rerank_response",
122137
true,
@@ -126,6 +141,7 @@ private record RankedDoc(@Nullable Float score, @Nullable String content) {
126141
static {
127142
PARSER.declareString(Builder::setContent, CONTENT);
128143
PARSER.declareFloat(Builder::setScore, SCORE);
144+
PARSER.declareString(Builder::setId, ID);
129145
}
130146

131147
public static RankedDoc parse(XContentParser parser) {
@@ -137,6 +153,7 @@ private static final class Builder {
137153

138154
private String content;
139155
private Float score;
156+
private String id;
140157

141158
private Builder() {}
142159

@@ -150,8 +167,13 @@ public Builder setContent(String content) {
150167
return this;
151168
}
152169

170+
public Builder setId(String id) {
171+
this.id = id;
172+
return this;
173+
}
174+
153175
public RankedDoc build() {
154-
return new RankedDoc(score, content);
176+
return new RankedDoc(score, content, id);
155177
}
156178
}
157179
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
3939
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
4040
);
4141

42-
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"))));
42+
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
4343
}
4444

4545
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -68,7 +68,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException
6868

6969
assertThat(
7070
parsedResults.getRankedDocs(),
71-
is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
71+
is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
7272
);
7373
}
7474

@@ -161,4 +161,37 @@ public void testFromResponse_FailsWhenScoreFieldIsNotPresent() {
161161

162162
assertThat(thrownException.getMessage(), is("Failed to find required field [score] in Google Vertex AI rerank response"));
163163
}
164+
165+
public void testFromResponse_FailsWhenIDFieldIsNotInteger() {
166+
String responseJson = """
167+
{
168+
"records": [
169+
{
170+
"id": "abcd",
171+
"title": "title 2",
172+
"content": "content 2",
173+
"score": 0.97
174+
},
175+
{
176+
"id": "1",
177+
"title": "title 1",
178+
"content": "content 1",
179+
"score": 0.96
180+
}
181+
]
182+
}
183+
""";
184+
185+
var thrownException = expectThrows(
186+
IllegalStateException.class,
187+
() -> GoogleVertexAiRerankResponseEntity.fromResponse(
188+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
189+
)
190+
);
191+
192+
assertThat(
193+
thrownException.getMessage(),
194+
is("Expected numeric value for record ID field in Google Vertex AI rerank response but received [abcd]")
195+
);
196+
}
164197
}

0 commit comments

Comments
 (0)