Skip to content

Commit 08504a8

Browse files
Fixing rerank response parser
1 parent 8007380 commit 08504a8

File tree

2 files changed

+165
-139
lines changed

2 files changed

+165
-139
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java

Lines changed: 67 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,27 @@
77

88
package org.elasticsearch.xpack.inference.services.jinaai.response;
99

10-
import org.apache.logging.log4j.LogManager;
11-
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
10+
import org.elasticsearch.core.Nullable;
1311
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.xcontent.ConstructingObjectParser;
13+
import org.elasticsearch.xcontent.ObjectParser;
14+
import org.elasticsearch.xcontent.ParseField;
1415
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentParseException;
1517
import org.elasticsearch.xcontent.XContentParser;
1618
import org.elasticsearch.xcontent.XContentParserConfiguration;
1719
import org.elasticsearch.xcontent.XContentType;
1820
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
1921
import org.elasticsearch.xpack.inference.external.http.HttpResult;
2022

2123
import java.io.IOException;
24+
import java.util.List;
2225

23-
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
24-
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
25-
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField;
26-
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
27-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
28-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
26+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
27+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
2928

3029
public class JinaAIRerankResponseEntity {
3130

32-
private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class);
33-
3431
/**
3532
* Parses the JinaAI ranked response.
3633
*
@@ -77,80 +74,76 @@ public class JinaAIRerankResponseEntity {
7774
* @throws IOException if there is an error parsing the response
7875
*/
7976
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
80-
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
81-
82-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
83-
moveToFirstToken(jsonParser);
84-
85-
XContentParser.Token token = jsonParser.currentToken();
86-
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
77+
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
78+
return Response.PARSER.apply(p, null).toRankedDocsResults();
79+
}
80+
}
8781

88-
positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE);
82+
private record Response(List<ResultItem> results) {
83+
@SuppressWarnings("unchecked")
84+
public static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
85+
Response.class.getSimpleName(),
86+
true,
87+
args -> new Response((List<ResultItem>) args[0])
88+
);
8989

90-
token = jsonParser.currentToken();
91-
if (token == XContentParser.Token.START_ARRAY) {
92-
return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject));
93-
} else {
94-
throwUnknownToken(token, jsonParser);
95-
}
90+
static {
91+
PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results"));
92+
}
9693

97-
// This should never be reached. The above code should either return successfully or hit the throwUnknownToken
98-
// or throw a parsing exception
99-
throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response");
94+
public RankedDocsResults toRankedDocsResults() {
95+
List<RankedDocsResults.RankedDoc> rankedDocs = results.stream()
96+
.map(item -> new RankedDocsResults.RankedDoc(
97+
item.index(),
98+
item.relevance_score(),
99+
item.document() != null ? item.document().text() : null
100+
))
101+
.toList();
102+
return new RankedDocsResults(rankedDocs);
100103
}
101104
}
102105

103-
private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException {
104-
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
105-
int index = -1;
106-
float relevanceScore = -1;
107-
String documentText = null;
108-
parser.nextToken();
109-
while (parser.currentToken() != XContentParser.Token.END_OBJECT) {
110-
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
111-
switch (parser.currentName()) {
112-
case "index":
113-
parser.nextToken(); // move to VALUE_NUMBER
114-
index = parser.intValue();
115-
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
116-
break;
117-
case "relevance_score":
118-
parser.nextToken(); // move to VALUE_NUMBER
119-
relevanceScore = parser.floatValue();
120-
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
121-
break;
122-
case "document":
123-
parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object
124-
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
125-
do {
126-
if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) {
127-
parser.nextToken(); // move to VALUE_STRING
128-
documentText = parser.text();
129-
}
130-
} while (parser.nextToken() != XContentParser.Token.END_OBJECT);
131-
parser.nextToken();// move past END_OBJECT
132-
// parser should now be at the next FIELD_NAME or END_OBJECT
133-
break;
134-
default:
135-
throwUnknownField(parser.currentName(), parser);
136-
}
137-
} else {
138-
parser.nextToken();
139-
}
106+
private record ResultItem(int index, float relevance_score, @Nullable Document document) {
107+
public static final ConstructingObjectParser<ResultItem, Void> PARSER = new ConstructingObjectParser<>(
108+
ResultItem.class.getSimpleName(),
109+
true,
110+
args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2])
111+
);
112+
113+
static {
114+
PARSER.declareInt(constructorArg(), new ParseField("index"));
115+
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
116+
PARSER.declareField(
117+
optionalConstructorArg(),
118+
(p, c) -> parseDocument(p),
119+
new ParseField("document"),
120+
ObjectParser.ValueType.OBJECT_OR_STRING
121+
);
140122
}
123+
}
141124

142-
if (index == -1) {
143-
logger.warn("Failed to find required field [index] in JinaAI rerank response");
144-
}
145-
if (relevanceScore == -1) {
146-
logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response");
125+
private record Document(String text) {}
126+
127+
private static Document parseDocument(XContentParser parser) throws IOException {
128+
var token = parser.currentToken();
129+
if (token == XContentParser.Token.START_OBJECT) {
130+
return new Document(DocumentObject.PARSER.apply(parser, null).text());
131+
} else if (token == XContentParser.Token.VALUE_STRING) {
132+
return new Document(parser.text());
147133
}
148-
// documentText may or may not be present depending on the request parameter
149134

150-
return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText);
135+
throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token);
151136
}
152137

153-
private JinaAIRerankResponseEntity() {}
138+
private record DocumentObject(String text) {
139+
public static final ConstructingObjectParser<DocumentObject, Void> PARSER = new ConstructingObjectParser<>(
140+
DocumentObject.class.getSimpleName(),
141+
true,
142+
args -> new DocumentObject((String) args[0])
143+
);
154144

155-
static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response";
145+
static {
146+
PARSER.declareString(constructorArg(), new ParseField("text"));
147+
}
148+
}
156149
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java

Lines changed: 98 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
package org.elasticsearch.xpack.inference.services.jinaai.response;
99

1010
import org.apache.http.HttpResponse;
11+
import org.elasticsearch.common.Strings;
1112
import org.elasticsearch.inference.InferenceServiceResults;
1213
import org.elasticsearch.test.ESTestCase;
1314
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
1415
import org.elasticsearch.xpack.inference.external.http.HttpResult;
15-
import org.hamcrest.MatcherAssert;
1616

1717
import java.io.IOException;
1818
import java.nio.charset.StandardCharsets;
@@ -26,11 +26,33 @@
2626
public class JinaAIRerankResponseEntityTests extends ESTestCase {
2727

2828
public void testResponseLiteral() throws IOException {
29+
String responseLiteral = """
30+
{
31+
"model": "model",
32+
"results": [
33+
{
34+
"index": 2,
35+
"relevance_score": 0.98005307
36+
},
37+
{
38+
"index": 3,
39+
"relevance_score": 0.27904198
40+
},
41+
{
42+
"index": 0,
43+
"relevance_score": 0.10194652
44+
}
45+
],
46+
"usage": {
47+
"total_tokens": 15
48+
}
49+
}
50+
""";
2951
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
3052
new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8))
3153
);
3254

33-
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
55+
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
3456
List<RankedDocsResults.RankedDoc> expected = responseLiteralDocs();
3557
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
3658
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
@@ -68,7 +90,7 @@ public void testGeneratedResponse() throws IOException {
6890
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
6991
new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8))
7092
);
71-
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
93+
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
7294
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
7395
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
7496
}
@@ -82,81 +104,92 @@ private ArrayList<RankedDocsResults.RankedDoc> responseLiteralDocs() {
82104
list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null));
83105
return list;
84106

85-
};
86-
87-
private final String responseLiteral = """
88-
{
89-
"model": "model",
90-
"results": [
91-
{
92-
"index": 2,
93-
"relevance_score": 0.98005307
94-
},
95-
{
96-
"index": 3,
97-
"relevance_score": 0.27904198
98-
},
99-
{
100-
"index": 0,
101-
"relevance_score": 0.10194652
102-
}
103-
],
104-
"usage": {
105-
"total_tokens": 15
106-
}
107-
}
108-
""";
107+
}
108+
109+
private final String WASHINGTON_TEXT = "Washington, D.C..";
110+
private final String CAPITAL_PUNISHMENT_TEXT =
111+
"Capital punishment has existed in the United States since before the United States was a country. ";
112+
private final String CARSON_CITY_TEXT = "Carson City is the capital city of the American state of Nevada.";
113+
114+
private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of(
115+
new RankedDocsResults.RankedDoc(2, 0.98005307F, WASHINGTON_TEXT),
116+
new RankedDocsResults.RankedDoc(3, 0.27904198F, CAPITAL_PUNISHMENT_TEXT),
117+
new RankedDocsResults.RankedDoc(0, 0.10194652F, CARSON_CITY_TEXT)
118+
);
109119

110120
public void testResponseLiteralWithDocuments() throws IOException {
121+
String responseLiteralWithDocuments = Strings.format("""
122+
{
123+
"model": "model",
124+
"results": [
125+
{
126+
"document": {
127+
"text": "%s"
128+
},
129+
"index": 2,
130+
"relevance_score": 0.98005307
131+
},
132+
{
133+
"document": {
134+
"text": "%s"
135+
},
136+
"index": 3,
137+
"relevance_score": 0.27904198
138+
},
139+
{
140+
"document": {
141+
"text": "%s"
142+
},
143+
"index": 0,
144+
"relevance_score": 0.10194652
145+
}
146+
],
147+
"usage": {
148+
"total_tokens": 15
149+
}
150+
}
151+
""", WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT);
111152
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
112153
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
113154
);
114155

115-
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
116-
MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
156+
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
157+
assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
117158
}
118159

119-
private final String responseLiteralWithDocuments = """
120-
{
121-
"model": "model",
122-
"results": [
123-
{
124-
"document": {
125-
"text": "Washington, D.C.."
126-
},
127-
"index": 2,
128-
"relevance_score": 0.98005307
129-
},
130-
{
131-
"document": {
132-
"text": "Capital punishment has existed in the United States since beforethe United States was a country. "
160+
public void testResponseLiteralWithDocumentsAsString() throws IOException {
161+
String responseLiteralWithDocuments = Strings.format("""
162+
{
163+
"model": "model",
164+
"results": [
165+
{
166+
"document": "%s",
167+
"index": 2,
168+
"relevance_score": 0.98005307
133169
},
134-
"index": 3,
135-
"relevance_score": 0.27904198
136-
},
137-
{
138-
"document": {
139-
"text": "Carson City is the capital city of the American state of Nevada."
170+
{
171+
"document": "%s",
172+
"index": 3,
173+
"relevance_score": 0.27904198
140174
},
141-
"index": 0,
142-
"relevance_score": 0.10194652
175+
{
176+
"document": "%s",
177+
"index": 0,
178+
"relevance_score": 0.10194652
179+
}
180+
],
181+
"usage": {
182+
"total_tokens": 15
143183
}
144-
],
145-
"usage": {
146-
"total_tokens": 15
147184
}
148-
}
149-
""";
185+
""", WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT);
186+
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
187+
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
188+
);
150189

151-
private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of(
152-
new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."),
153-
new RankedDocsResults.RankedDoc(
154-
3,
155-
0.27904198F,
156-
"Capital punishment has existed in the United States since beforethe United States was a country. "
157-
),
158-
new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.")
159-
);
190+
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
191+
assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
192+
}
160193

161194
private ArrayList<Integer> linear(int n) {
162195
ArrayList<Integer> list = new ArrayList<>();

0 commit comments

Comments
 (0)