From db7fb624c24870fe1d580b4ef71ce0f847ae9113 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Fri, 17 Oct 2025 14:10:08 -0400 Subject: [PATCH] [ML] Adjust jinaai rerank response parser to handle document field as string or object (#136751) * Fixing rerank response parser * Update docs/changelog/136751.yaml * [CI] Auto commit changes from spotless * Addressing feedback * Updating comment about response format --------- Co-authored-by: elasticsearchmachine --- docs/changelog/136751.yaml | 6 + .../response/JinaAIRerankResponseEntity.java | 174 ++++++++++-------- .../JinaAIRerankResponseEntityTests.java | 165 ++++++++++------- 3 files changed, 205 insertions(+), 140 deletions(-) create mode 100644 docs/changelog/136751.yaml diff --git a/docs/changelog/136751.yaml b/docs/changelog/136751.yaml new file mode 100644 index 0000000000000..351814747e9b4 --- /dev/null +++ b/docs/changelog/136751.yaml @@ -0,0 +1,6 @@ +pr: 136751 +summary: Adjust jinaai rerank response parser to handle document field as string or + object +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java index a34d962f8bd2c..3acede30d5cdf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntity.java @@ -7,11 +7,13 @@ package org.elasticsearch.xpack.inference.services.jinaai.response; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; 
import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; @@ -19,18 +21,13 @@ import org.elasticsearch.xpack.inference.external.http.HttpResult; import java.io.IOException; +import java.util.List; -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; public class JinaAIRerankResponseEntity { - private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class); - /** * Parses the JinaAI ranked response. * @@ -44,6 +41,7 @@ public class JinaAIRerankResponseEntity { * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] *

* The response will look like (without whitespace): + *

      *     {
      *     "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
      *     "results": [
@@ -71,86 +69,114 @@ public class JinaAIRerankResponseEntity {
      *     ],
      *     "usage": {"total_tokens": 15}
      *     }
+     *  
+ * + * Or like this, where each result's document field is a string: + *
+     *      {
+     *      "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
+     *      "results": [
+     *           {
+     *                "document": "Washington, D.C.  is the capital of the United States. It is a federal district.",
+     *                "index": 2,
+     *                "relevance_score": 0.98005307
+     *           },
+     *           {
+     *                "document":  "abc",
+     *                "index": 3,
+     *                "relevance_score": 0.27904198
+     *           },
+     *           {
+     *                "document": "Carson City is the capital city of the American state of Nevada.",
+     *                "index": 0,
+     *                "relevance_score": 0.10194652
+     *           }
+     *      ],
+     *      "usage": {
+     *           "total_tokens": 15
+     *      }
+     * }
+     *  
+ * + * This parsing logic handles both cases. * * @param response the http response from JinaAI * @return the parsed response * @throws IOException if there is an error parsing the response */ public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { - var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { - moveToFirstToken(jsonParser); - - XContentParser.Token token = jsonParser.currentToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { + return Response.PARSER.apply(p, null).toRankedDocsResults(); + } + } - positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); + private record Response(List results) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Response.class.getSimpleName(), + true, + args -> new Response((List) args[0]) + ); - token = jsonParser.currentToken(); - if (token == XContentParser.Token.START_ARRAY) { - return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject)); - } else { - throwUnknownToken(token, jsonParser); - } + static { + PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results")); + } - // This should never be reached. 
The above code should either return successfully or hit the throwUnknownToken - // or throw a parsing exception - throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response"); + public RankedDocsResults toRankedDocsResults() { + List rankedDocs = results.stream() + .map( + item -> new RankedDocsResults.RankedDoc( + item.index(), + item.relevanceScore(), + item.document() != null ? item.document().text() : null + ) + ) + .toList(); + return new RankedDocsResults(rankedDocs); } } - private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - int index = -1; - float relevanceScore = -1; - String documentText = null; - parser.nextToken(); - while (parser.currentToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "index": - parser.nextToken(); // move to VALUE_NUMBER - index = parser.intValue(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "relevance_score": - parser.nextToken(); // move to VALUE_NUMBER - relevanceScore = parser.floatValue(); - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT - break; - case "document": - parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - do { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { - parser.nextToken(); // move to VALUE_STRING - documentText = parser.text(); - } - } while (parser.nextToken() != XContentParser.Token.END_OBJECT); - parser.nextToken();// move past END_OBJECT - // parser should now be at the next FIELD_NAME or END_OBJECT - break; - default: - throwUnknownField(parser.currentName(), parser); - } - } else { - parser.nextToken(); 
- } + private record ResultItem(int index, float relevanceScore, @Nullable Document document) { + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ResultItem.class.getSimpleName(), + true, + args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2]) + ); + + static { + PARSER.declareInt(constructorArg(), new ParseField("index")); + PARSER.declareFloat(constructorArg(), new ParseField("relevance_score")); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> parseDocument(p), + new ParseField("document"), + ObjectParser.ValueType.OBJECT_OR_STRING + ); } + } - if (index == -1) { - logger.warn("Failed to find required field [index] in JinaAI rerank response"); - } - if (relevanceScore == -1) { - logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response"); + private record Document(String text) {} + + private static Document parseDocument(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_OBJECT) { + return new Document(DocumentObject.PARSER.apply(parser, null).text()); + } else if (token == XContentParser.Token.VALUE_STRING) { + return new Document(parser.text()); } - // documentText may or may not be present depending on the request parameter - return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); + throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token); } - private JinaAIRerankResponseEntity() {} + private record DocumentObject(String text) { + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + DocumentObject.class.getSimpleName(), + true, + args -> new DocumentObject((String) args[0]) + ); - static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response"; + static { + PARSER.declareString(constructorArg(), new 
ParseField("text")); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java index 5019a1808f3cc..96ac390f4cdcc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/response/JinaAIRerankResponseEntityTests.java @@ -8,11 +8,11 @@ package org.elasticsearch.xpack.inference.services.jinaai.response; import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.hamcrest.MatcherAssert; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -25,12 +25,45 @@ public class JinaAIRerankResponseEntityTests extends ESTestCase { + private static final String WASHINGTON_TEXT = "Washington, D.C.."; + private static final String CAPITAL_PUNISHMENT_TEXT = + "Capital punishment has existed in the United States since before the United States was a country. 
"; + private static final String CARSON_CITY_TEXT = "Carson City is the capital city of the American state of Nevada."; + + private static final List RESPONSE_LITERAL_DOCS_WITH_TEXT = List.of( + new RankedDocsResults.RankedDoc(2, 0.98005307F, WASHINGTON_TEXT), + new RankedDocsResults.RankedDoc(3, 0.27904198F, CAPITAL_PUNISHMENT_TEXT), + new RankedDocsResults.RankedDoc(0, 0.10194652F, CARSON_CITY_TEXT) + ); + public void testResponseLiteral() throws IOException { + String responseLiteral = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 3, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + assertThat(parsedResults, instanceOf(RankedDocsResults.class)); List expected = responseLiteralDocs(); for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); @@ -68,7 +101,7 @@ public void testGeneratedResponse() throws IOException { InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + assertThat(parsedResults, instanceOf(RankedDocsResults.class)); for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); } @@ -82,81 +115,81 @@ private ArrayList responseLiteralDocs() { list.add(new 
RankedDocsResults.RankedDoc(0, 0.10194652F, null)); return list; - }; - - private final String responseLiteral = """ - { - "model": "model", - "results": [ - { - "index": 2, - "relevance_score": 0.98005307 - }, - { - "index": 3, - "relevance_score": 0.27904198 - }, - { - "index": 0, - "relevance_score": 0.10194652 - } - ], - "usage": { - "total_tokens": 15 - } - } - """; + } public void testResponseLiteralWithDocuments() throws IOException { + String responseLiteralWithDocuments = Strings.format(""" + { + "model": "model", + "results": [ + { + "document": { + "text": "%s" + }, + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": { + "text": "%s" + }, + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": { + "text": "%s" + }, + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """, WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT); InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); - MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); + assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(RESPONSE_LITERAL_DOCS_WITH_TEXT)); } - private final String responseLiteralWithDocuments = """ - { - "model": "model", - "results": [ - { - "document": { - "text": "Washington, D.C.." 
+ public void testResponseLiteralWithDocumentsAsString() throws IOException { + String responseLiteralWithDocuments = Strings.format(""" + { + "model": "model", + "results": [ + { + "document": "%s", + "index": 2, + "relevance_score": 0.98005307 }, - "index": 2, - "relevance_score": 0.98005307 - }, - { - "document": { - "text": "Capital punishment has existed in the United States since beforethe United States was a country. " + { + "document": "%s", + "index": 3, + "relevance_score": 0.27904198 }, - "index": 3, - "relevance_score": 0.27904198 - }, - { - "document": { - "text": "Carson City is the capital city of the American state of Nevada." - }, - "index": 0, - "relevance_score": 0.10194652 + { + "document": "%s", + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 } - ], - "usage": { - "total_tokens": 15 } - } - """; - - private final List responseLiteralDocsWithText = List.of( - new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."), - new RankedDocsResults.RankedDoc( - 3, - 0.27904198F, - "Capital punishment has existed in the United States since beforethe United States was a country. " - ), - new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.") - ); + """, WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT); + InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(RESPONSE_LITERAL_DOCS_WITH_TEXT)); + } private ArrayList linear(int n) { ArrayList list = new ArrayList<>();