Skip to content

Commit c28a887

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Adjust jinaai rerank response parser to handle document field as string or object (elastic#136751) (elastic#136765)
* Fixing rerank response parser * Update docs/changelog/136751.yaml * [CI] Auto commit changes from spotless * Addressing feedback * Updating comment about response format --------- (cherry picked from commit a1d7b8a) Co-authored-by: elasticsearchmachine <[email protected]>
1 parent e3c1c3c commit c28a887

File tree

3 files changed

+205
-140
lines changed

3 files changed

+205
-140
lines changed

docs/changelog/136751.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 136751
2+
summary: Adjust jinaai rerank response parser to handle document field as string or
3+
object
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java

Lines changed: 100 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,27 @@
77

88
package org.elasticsearch.xpack.inference.services.jinaai.response;
99

10-
import org.apache.logging.log4j.LogManager;
11-
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
10+
import org.elasticsearch.core.Nullable;
1311
import org.elasticsearch.inference.InferenceServiceResults;
12+
import org.elasticsearch.xcontent.ConstructingObjectParser;
13+
import org.elasticsearch.xcontent.ObjectParser;
14+
import org.elasticsearch.xcontent.ParseField;
1415
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentParseException;
1517
import org.elasticsearch.xcontent.XContentParser;
1618
import org.elasticsearch.xcontent.XContentParserConfiguration;
1719
import org.elasticsearch.xcontent.XContentType;
1820
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
1921
import org.elasticsearch.xpack.inference.external.http.HttpResult;
2022

2123
import java.io.IOException;
24+
import java.util.List;
2225

23-
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
24-
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
25-
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField;
26-
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
27-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
28-
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
26+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
27+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
2928

3029
public class JinaAIRerankResponseEntity {
3130

32-
private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class);
33-
3431
/**
3532
* Parses the JinaAI ranked response.
3633
*
@@ -44,6 +41,7 @@ public class JinaAIRerankResponseEntity {
4441
* "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."]
4542
* <p>
4643
* The response will look like (without whitespace):
44+
* <pre>
4745
* {
4846
* "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
4947
* "results": [
@@ -71,86 +69,114 @@ public class JinaAIRerankResponseEntity {
7169
* ],
7270
* "usage": {"total_tokens": 15}
7371
* }
72+
* </pre>
73+
*
74+
* Or like this where documents is a string:
75+
* <pre>
76+
* {
77+
* "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
78+
* "results": [
79+
* {
80+
* "document": "Washington, D.C. is the capital of the United States. It is a federal district.",
81+
* "index": 2,
82+
* "relevance_score": 0.98005307
83+
* },
84+
* {
85+
* "document": "abc",
86+
* "index": 3,
87+
* "relevance_score": 0.27904198
88+
* },
89+
* {
90+
* "document": "Carson City is the capital city of the American state of Nevada.",
91+
* "index": 0,
92+
* "relevance_score": 0.10194652
93+
* }
94+
* ],
95+
* "usage": {
96+
* "total_tokens": 15
97+
* }
98+
* }
99+
* </pre>
100+
*
101+
* This parsing logic handles both cases.
74102
*
75103
* @param response the http response from JinaAI
76104
* @return the parsed response
77105
* @throws IOException if there is an error parsing the response
78106
*/
79107
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
80-
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
81-
82-
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
83-
moveToFirstToken(jsonParser);
84-
85-
XContentParser.Token token = jsonParser.currentToken();
86-
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
108+
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
109+
return Response.PARSER.apply(p, null).toRankedDocsResults();
110+
}
111+
}
87112

88-
positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE);
113+
private record Response(List<ResultItem> results) {
114+
@SuppressWarnings("unchecked")
115+
public static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
116+
Response.class.getSimpleName(),
117+
true,
118+
args -> new Response((List<ResultItem>) args[0])
119+
);
89120

90-
token = jsonParser.currentToken();
91-
if (token == XContentParser.Token.START_ARRAY) {
92-
return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject));
93-
} else {
94-
throwUnknownToken(token, jsonParser);
95-
}
121+
static {
122+
PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results"));
123+
}
96124

97-
// This should never be reached. The above code should either return successfully or hit the throwUnknownToken
98-
// or throw a parsing exception
99-
throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response");
125+
public RankedDocsResults toRankedDocsResults() {
126+
List<RankedDocsResults.RankedDoc> rankedDocs = results.stream()
127+
.map(
128+
item -> new RankedDocsResults.RankedDoc(
129+
item.index(),
130+
item.relevanceScore(),
131+
item.document() != null ? item.document().text() : null
132+
)
133+
)
134+
.toList();
135+
return new RankedDocsResults(rankedDocs);
100136
}
101137
}
102138

103-
private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException {
104-
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
105-
int index = -1;
106-
float relevanceScore = -1;
107-
String documentText = null;
108-
parser.nextToken();
109-
while (parser.currentToken() != XContentParser.Token.END_OBJECT) {
110-
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
111-
switch (parser.currentName()) {
112-
case "index":
113-
parser.nextToken(); // move to VALUE_NUMBER
114-
index = parser.intValue();
115-
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
116-
break;
117-
case "relevance_score":
118-
parser.nextToken(); // move to VALUE_NUMBER
119-
relevanceScore = parser.floatValue();
120-
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
121-
break;
122-
case "document":
123-
parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object
124-
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
125-
do {
126-
if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) {
127-
parser.nextToken(); // move to VALUE_STRING
128-
documentText = parser.text();
129-
}
130-
} while (parser.nextToken() != XContentParser.Token.END_OBJECT);
131-
parser.nextToken();// move past END_OBJECT
132-
// parser should now be at the next FIELD_NAME or END_OBJECT
133-
break;
134-
default:
135-
throwUnknownField(parser.currentName(), parser);
136-
}
137-
} else {
138-
parser.nextToken();
139-
}
139+
private record ResultItem(int index, float relevanceScore, @Nullable Document document) {
140+
public static final ConstructingObjectParser<ResultItem, Void> PARSER = new ConstructingObjectParser<>(
141+
ResultItem.class.getSimpleName(),
142+
true,
143+
args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2])
144+
);
145+
146+
static {
147+
PARSER.declareInt(constructorArg(), new ParseField("index"));
148+
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
149+
PARSER.declareField(
150+
optionalConstructorArg(),
151+
(p, c) -> parseDocument(p),
152+
new ParseField("document"),
153+
ObjectParser.ValueType.OBJECT_OR_STRING
154+
);
140155
}
156+
}
141157

142-
if (index == -1) {
143-
logger.warn("Failed to find required field [index] in JinaAI rerank response");
144-
}
145-
if (relevanceScore == -1) {
146-
logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response");
158+
private record Document(String text) {}
159+
160+
private static Document parseDocument(XContentParser parser) throws IOException {
161+
var token = parser.currentToken();
162+
if (token == XContentParser.Token.START_OBJECT) {
163+
return new Document(DocumentObject.PARSER.apply(parser, null).text());
164+
} else if (token == XContentParser.Token.VALUE_STRING) {
165+
return new Document(parser.text());
147166
}
148-
// documentText may or may not be present depending on the request parameter
149167

150-
return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText);
168+
throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token);
151169
}
152170

153-
private JinaAIRerankResponseEntity() {}
171+
private record DocumentObject(String text) {
172+
public static final ConstructingObjectParser<DocumentObject, Void> PARSER = new ConstructingObjectParser<>(
173+
DocumentObject.class.getSimpleName(),
174+
true,
175+
args -> new DocumentObject((String) args[0])
176+
);
154177

155-
static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response";
178+
static {
179+
PARSER.declareString(constructorArg(), new ParseField("text"));
180+
}
181+
}
156182
}

0 commit comments

Comments
 (0)