|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.inference.services.jinaai.response; |
9 | 9 |
|
10 | | -import org.apache.logging.log4j.LogManager; |
11 | | -import org.apache.logging.log4j.Logger; |
12 | | -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; |
| 10 | +import org.elasticsearch.core.Nullable; |
13 | 11 | import org.elasticsearch.inference.InferenceServiceResults; |
| 12 | +import org.elasticsearch.xcontent.ConstructingObjectParser; |
| 13 | +import org.elasticsearch.xcontent.ObjectParser; |
| 14 | +import org.elasticsearch.xcontent.ParseField; |
14 | 15 | import org.elasticsearch.xcontent.XContentFactory; |
| 16 | +import org.elasticsearch.xcontent.XContentParseException; |
15 | 17 | import org.elasticsearch.xcontent.XContentParser; |
16 | 18 | import org.elasticsearch.xcontent.XContentParserConfiguration; |
17 | 19 | import org.elasticsearch.xcontent.XContentType; |
18 | 20 | import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; |
19 | 21 | import org.elasticsearch.xpack.inference.external.http.HttpResult; |
20 | 22 |
|
21 | 23 | import java.io.IOException; |
| 24 | +import java.util.List; |
22 | 25 |
|
23 | | -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; |
24 | | -import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; |
25 | | -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; |
26 | | -import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; |
27 | | -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; |
28 | | -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; |
| 26 | +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; |
| 27 | +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; |
29 | 28 |
|
30 | 29 | public class JinaAIRerankResponseEntity { |
31 | 30 |
|
32 | | - private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class); |
33 | | - |
34 | 31 | /** |
35 | 32 | * Parses the JinaAI ranked response. |
36 | 33 | * |
@@ -77,80 +74,76 @@ public class JinaAIRerankResponseEntity { |
77 | 74 | * @throws IOException if there is an error parsing the response |
78 | 75 | */ |
79 | 76 | public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { |
80 | | - var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); |
81 | | - |
82 | | - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { |
83 | | - moveToFirstToken(jsonParser); |
84 | | - |
85 | | - XContentParser.Token token = jsonParser.currentToken(); |
86 | | - ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); |
| 77 | + try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { |
| 78 | + return Response.PARSER.apply(p, null).toRankedDocsResults(); |
| 79 | + } |
| 80 | + } |
87 | 81 |
|
88 | | - positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); |
| 82 | + private record Response(List<ResultItem> results) { |
| 83 | + @SuppressWarnings("unchecked") |
| 84 | + public static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>( |
| 85 | + Response.class.getSimpleName(), |
| 86 | + true, |
| 87 | + args -> new Response((List<ResultItem>) args[0]) |
| 88 | + ); |
89 | 89 |
|
90 | | - token = jsonParser.currentToken(); |
91 | | - if (token == XContentParser.Token.START_ARRAY) { |
92 | | - return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject)); |
93 | | - } else { |
94 | | - throwUnknownToken(token, jsonParser); |
95 | | - } |
| 90 | + static { |
| 91 | + PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results")); |
| 92 | + } |
96 | 93 |
|
97 | | - // This should never be reached. The above code should either return successfully or hit the throwUnknownToken |
98 | | - // or throw a parsing exception |
99 | | - throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response"); |
| 94 | + public RankedDocsResults toRankedDocsResults() { |
| 95 | + List<RankedDocsResults.RankedDoc> rankedDocs = results.stream() |
| 96 | + .map(item -> new RankedDocsResults.RankedDoc( |
| 97 | + item.index(), |
| 98 | + item.relevance_score(), |
| 99 | + item.document() != null ? item.document().text() : null |
| 100 | + )) |
| 101 | + .toList(); |
| 102 | + return new RankedDocsResults(rankedDocs); |
100 | 103 | } |
101 | 104 | } |
102 | 105 |
|
103 | | - private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { |
104 | | - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); |
105 | | - int index = -1; |
106 | | - float relevanceScore = -1; |
107 | | - String documentText = null; |
108 | | - parser.nextToken(); |
109 | | - while (parser.currentToken() != XContentParser.Token.END_OBJECT) { |
110 | | - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { |
111 | | - switch (parser.currentName()) { |
112 | | - case "index": |
113 | | - parser.nextToken(); // move to VALUE_NUMBER |
114 | | - index = parser.intValue(); |
115 | | - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT |
116 | | - break; |
117 | | - case "relevance_score": |
118 | | - parser.nextToken(); // move to VALUE_NUMBER |
119 | | - relevanceScore = parser.floatValue(); |
120 | | - parser.nextToken(); // move to next FIELD_NAME or END_OBJECT |
121 | | - break; |
122 | | - case "document": |
123 | | - parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object |
124 | | - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); |
125 | | - do { |
126 | | - if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { |
127 | | - parser.nextToken(); // move to VALUE_STRING |
128 | | - documentText = parser.text(); |
129 | | - } |
130 | | - } while (parser.nextToken() != XContentParser.Token.END_OBJECT); |
131 | | - parser.nextToken();// move past END_OBJECT |
132 | | - // parser should now be at the next FIELD_NAME or END_OBJECT |
133 | | - break; |
134 | | - default: |
135 | | - throwUnknownField(parser.currentName(), parser); |
136 | | - } |
137 | | - } else { |
138 | | - parser.nextToken(); |
139 | | - } |
| 106 | + private record ResultItem(int index, float relevance_score, @Nullable Document document) { |
| 107 | + public static final ConstructingObjectParser<ResultItem, Void> PARSER = new ConstructingObjectParser<>( |
| 108 | + ResultItem.class.getSimpleName(), |
| 109 | + true, |
| 110 | + args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2]) |
| 111 | + ); |
| 112 | + |
| 113 | + static { |
| 114 | + PARSER.declareInt(constructorArg(), new ParseField("index")); |
| 115 | + PARSER.declareFloat(constructorArg(), new ParseField("relevance_score")); |
| 116 | + PARSER.declareField( |
| 117 | + optionalConstructorArg(), |
| 118 | + (p, c) -> parseDocument(p), |
| 119 | + new ParseField("document"), |
| 120 | + ObjectParser.ValueType.OBJECT_OR_STRING |
| 121 | + ); |
140 | 122 | } |
| 123 | + } |
141 | 124 |
|
142 | | - if (index == -1) { |
143 | | - logger.warn("Failed to find required field [index] in JinaAI rerank response"); |
144 | | - } |
145 | | - if (relevanceScore == -1) { |
146 | | - logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response"); |
| 125 | + private record Document(String text) {} |
| 126 | + |
| 127 | + private static Document parseDocument(XContentParser parser) throws IOException { |
| 128 | + var token = parser.currentToken(); |
| 129 | + if (token == XContentParser.Token.START_OBJECT) { |
| 130 | + return new Document(DocumentObject.PARSER.apply(parser, null).text()); |
| 131 | + } else if (token == XContentParser.Token.VALUE_STRING) { |
| 132 | + return new Document(parser.text()); |
147 | 133 | } |
148 | | - // documentText may or may not be present depending on the request parameter |
149 | 134 |
|
150 | | - return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); |
| 135 | + throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token); |
151 | 136 | } |
152 | 137 |
|
153 | | - private JinaAIRerankResponseEntity() {} |
| 138 | + private record DocumentObject(String text) { |
| 139 | + public static final ConstructingObjectParser<DocumentObject, Void> PARSER = new ConstructingObjectParser<>( |
| 140 | + DocumentObject.class.getSimpleName(), |
| 141 | + true, |
| 142 | + args -> new DocumentObject((String) args[0]) |
| 143 | + ); |
154 | 144 |
|
155 | | - static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response"; |
| 145 | + static { |
| 146 | + PARSER.declareString(constructorArg(), new ParseField("text")); |
| 147 | + } |
| 148 | + } |
156 | 149 | } |
0 commit comments