diff --git a/docs/changelog/136751.yaml b/docs/changelog/136751.yaml
new file mode 100644
index 0000000000000..351814747e9b4
--- /dev/null
+++ b/docs/changelog/136751.yaml
@@ -0,0 +1,6 @@
+pr: 136751
+summary: Adjust jinaai rerank response parser to handle document field as string or
+ object
+area: Machine Learning
+type: bug
+issues: []
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java
index a34d962f8bd2c..3acede30d5cdf 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java
@@ -7,11 +7,13 @@
package org.elasticsearch.xpack.inference.services.jinaai.response;
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
@@ -19,18 +21,13 @@
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import java.io.IOException;
+import java.util.List;
-import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
-import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
-import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField;
-import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
-import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
-import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class JinaAIRerankResponseEntity {
- private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class);
-
/**
* Parses the JinaAI ranked response.
*
@@ -44,6 +41,7 @@ public class JinaAIRerankResponseEntity {
* "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."]
*
* The response will look like (without whitespace):
+ *
* {
* "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
* "results": [
@@ -71,86 +69,114 @@ public class JinaAIRerankResponseEntity {
* ],
* "usage": {"total_tokens": 15}
* }
+ *
+ *
+ * Or like this where documents is a string:
+ *
+ * {
+ * "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
+ * "results": [
+ * {
+ * "document": "Washington, D.C. is the capital of the United States. It is a federal district.",
+ * "index": 2,
+ * "relevance_score": 0.98005307
+ * },
+ * {
+ * "document": "abc",
+ * "index": 3,
+ * "relevance_score": 0.27904198
+ * },
+ * {
+ * "document": "Carson City is the capital city of the American state of Nevada.",
+ * "index": 0,
+ * "relevance_score": 0.10194652
+ * }
+ * ],
+ * "usage": {
+ * "total_tokens": 15
+ * }
+ * }
+ *
+ *
+ * This parsing logic handles both cases.
*
* @param response the http response from JinaAI
* @return the parsed response
* @throws IOException if there is an error parsing the response
*/
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
- var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
-
- try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
- moveToFirstToken(jsonParser);
-
- XContentParser.Token token = jsonParser.currentToken();
- ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+ try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
+ return Response.PARSER.apply(p, null).toRankedDocsResults();
+ }
+ }
- positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE);
+ private record Response(List results) {
+ @SuppressWarnings("unchecked")
+ public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ Response.class.getSimpleName(),
+ true,
+ args -> new Response((List) args[0])
+ );
- token = jsonParser.currentToken();
- if (token == XContentParser.Token.START_ARRAY) {
- return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject));
- } else {
- throwUnknownToken(token, jsonParser);
- }
+ static {
+ PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results"));
+ }
- // This should never be reached. The above code should either return successfully or hit the throwUnknownToken
- // or throw a parsing exception
- throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response");
+ public RankedDocsResults toRankedDocsResults() {
+ List rankedDocs = results.stream()
+ .map(
+ item -> new RankedDocsResults.RankedDoc(
+ item.index(),
+ item.relevanceScore(),
+ item.document() != null ? item.document().text() : null
+ )
+ )
+ .toList();
+ return new RankedDocsResults(rankedDocs);
}
}
- private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException {
- ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
- int index = -1;
- float relevanceScore = -1;
- String documentText = null;
- parser.nextToken();
- while (parser.currentToken() != XContentParser.Token.END_OBJECT) {
- if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
- switch (parser.currentName()) {
- case "index":
- parser.nextToken(); // move to VALUE_NUMBER
- index = parser.intValue();
- parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
- break;
- case "relevance_score":
- parser.nextToken(); // move to VALUE_NUMBER
- relevanceScore = parser.floatValue();
- parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
- break;
- case "document":
- parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object
- ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
- do {
- if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) {
- parser.nextToken(); // move to VALUE_STRING
- documentText = parser.text();
- }
- } while (parser.nextToken() != XContentParser.Token.END_OBJECT);
- parser.nextToken();// move past END_OBJECT
- // parser should now be at the next FIELD_NAME or END_OBJECT
- break;
- default:
- throwUnknownField(parser.currentName(), parser);
- }
- } else {
- parser.nextToken();
- }
+ private record ResultItem(int index, float relevanceScore, @Nullable Document document) {
+ public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ ResultItem.class.getSimpleName(),
+ true,
+ args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2])
+ );
+
+ static {
+ PARSER.declareInt(constructorArg(), new ParseField("index"));
+ PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
+ PARSER.declareField(
+ optionalConstructorArg(),
+ (p, c) -> parseDocument(p),
+ new ParseField("document"),
+ ObjectParser.ValueType.OBJECT_OR_STRING
+ );
}
+ }
- if (index == -1) {
- logger.warn("Failed to find required field [index] in JinaAI rerank response");
- }
- if (relevanceScore == -1) {
- logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response");
+ private record Document(String text) {}
+
+ private static Document parseDocument(XContentParser parser) throws IOException {
+ var token = parser.currentToken();
+ if (token == XContentParser.Token.START_OBJECT) {
+ return new Document(DocumentObject.PARSER.apply(parser, null).text());
+ } else if (token == XContentParser.Token.VALUE_STRING) {
+ return new Document(parser.text());
}
- // documentText may or may not be present depending on the request parameter
- return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText);
+ throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token);
}
- private JinaAIRerankResponseEntity() {}
+ private record DocumentObject(String text) {
+ public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ DocumentObject.class.getSimpleName(),
+ true,
+ args -> new DocumentObject((String) args[0])
+ );
- static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response";
+ static {
+ PARSER.declareString(constructorArg(), new ParseField("text"));
+ }
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java
index 5019a1808f3cc..96ac390f4cdcc 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java
@@ -8,11 +8,11 @@
package org.elasticsearch.xpack.inference.services.jinaai.response;
import org.apache.http.HttpResponse;
+import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
-import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@@ -25,12 +25,45 @@
public class JinaAIRerankResponseEntityTests extends ESTestCase {
+ private static final String WASHINGTON_TEXT = "Washington, D.C..";
+ private static final String CAPITAL_PUNISHMENT_TEXT =
+ "Capital punishment has existed in the United States since before the United States was a country. ";
+ private static final String CARSON_CITY_TEXT = "Carson City is the capital city of the American state of Nevada.";
+
+ private static final List RESPONSE_LITERAL_DOCS_WITH_TEXT = List.of(
+ new RankedDocsResults.RankedDoc(2, 0.98005307F, WASHINGTON_TEXT),
+ new RankedDocsResults.RankedDoc(3, 0.27904198F, CAPITAL_PUNISHMENT_TEXT),
+ new RankedDocsResults.RankedDoc(0, 0.10194652F, CARSON_CITY_TEXT)
+ );
+
public void testResponseLiteral() throws IOException {
+ String responseLiteral = """
+ {
+ "model": "model",
+ "results": [
+ {
+ "index": 2,
+ "relevance_score": 0.98005307
+ },
+ {
+ "index": 3,
+ "relevance_score": 0.27904198
+ },
+ {
+ "index": 0,
+ "relevance_score": 0.10194652
+ }
+ ],
+ "usage": {
+ "total_tokens": 15
+ }
+ }
+ """;
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8))
);
- MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
+ assertThat(parsedResults, instanceOf(RankedDocsResults.class));
List expected = responseLiteralDocs();
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
@@ -68,7 +101,7 @@ public void testGeneratedResponse() throws IOException {
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8))
);
- MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
+ assertThat(parsedResults, instanceOf(RankedDocsResults.class));
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
}
@@ -82,81 +115,81 @@ private ArrayList responseLiteralDocs() {
list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null));
return list;
- };
-
- private final String responseLiteral = """
- {
- "model": "model",
- "results": [
- {
- "index": 2,
- "relevance_score": 0.98005307
- },
- {
- "index": 3,
- "relevance_score": 0.27904198
- },
- {
- "index": 0,
- "relevance_score": 0.10194652
- }
- ],
- "usage": {
- "total_tokens": 15
- }
- }
- """;
+ }
public void testResponseLiteralWithDocuments() throws IOException {
+ String responseLiteralWithDocuments = Strings.format("""
+ {
+ "model": "model",
+ "results": [
+ {
+ "document": {
+ "text": "%s"
+ },
+ "index": 2,
+ "relevance_score": 0.98005307
+ },
+ {
+ "document": {
+ "text": "%s"
+ },
+ "index": 3,
+ "relevance_score": 0.27904198
+ },
+ {
+ "document": {
+ "text": "%s"
+ },
+ "index": 0,
+ "relevance_score": 0.10194652
+ }
+ ],
+ "usage": {
+ "total_tokens": 15
+ }
+ }
+ """, WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT);
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
);
- MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
- MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
+ assertThat(parsedResults, instanceOf(RankedDocsResults.class));
+ assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(RESPONSE_LITERAL_DOCS_WITH_TEXT));
}
- private final String responseLiteralWithDocuments = """
- {
- "model": "model",
- "results": [
- {
- "document": {
- "text": "Washington, D.C.."
+ public void testResponseLiteralWithDocumentsAsString() throws IOException {
+ String responseLiteralWithDocuments = Strings.format("""
+ {
+ "model": "model",
+ "results": [
+ {
+ "document": "%s",
+ "index": 2,
+ "relevance_score": 0.98005307
},
- "index": 2,
- "relevance_score": 0.98005307
- },
- {
- "document": {
- "text": "Capital punishment has existed in the United States since beforethe United States was a country. "
+ {
+ "document": "%s",
+ "index": 3,
+ "relevance_score": 0.27904198
},
- "index": 3,
- "relevance_score": 0.27904198
- },
- {
- "document": {
- "text": "Carson City is the capital city of the American state of Nevada."
- },
- "index": 0,
- "relevance_score": 0.10194652
+ {
+ "document": "%s",
+ "index": 0,
+ "relevance_score": 0.10194652
+ }
+ ],
+ "usage": {
+ "total_tokens": 15
}
- ],
- "usage": {
- "total_tokens": 15
}
- }
- """;
-
- private final List responseLiteralDocsWithText = List.of(
- new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."),
- new RankedDocsResults.RankedDoc(
- 3,
- 0.27904198F,
- "Capital punishment has existed in the United States since beforethe United States was a country. "
- ),
- new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.")
- );
+ """, WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT);
+ InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
+ new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(parsedResults, instanceOf(RankedDocsResults.class));
+ assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(RESPONSE_LITERAL_DOCS_WITH_TEXT));
+ }
private ArrayList linear(int n) {
ArrayList list = new ArrayList<>();