77
88package org .elasticsearch .xpack .inference .services .jinaai .response ;
99
10- import org .apache .logging .log4j .LogManager ;
11- import org .apache .logging .log4j .Logger ;
12- import org .elasticsearch .common .xcontent .LoggingDeprecationHandler ;
10+ import org .elasticsearch .core .Nullable ;
1311import org .elasticsearch .inference .InferenceServiceResults ;
12+ import org .elasticsearch .xcontent .ConstructingObjectParser ;
13+ import org .elasticsearch .xcontent .ObjectParser ;
14+ import org .elasticsearch .xcontent .ParseField ;
1415import org .elasticsearch .xcontent .XContentFactory ;
16+ import org .elasticsearch .xcontent .XContentParseException ;
1517import org .elasticsearch .xcontent .XContentParser ;
1618import org .elasticsearch .xcontent .XContentParserConfiguration ;
1719import org .elasticsearch .xcontent .XContentType ;
1820import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
1921import org .elasticsearch .xpack .inference .external .http .HttpResult ;
2022
2123import java .io .IOException ;
24+ import java .util .List ;
2225
23- import static org .elasticsearch .common .xcontent .XContentParserUtils .ensureExpectedToken ;
24- import static org .elasticsearch .common .xcontent .XContentParserUtils .parseList ;
25- import static org .elasticsearch .common .xcontent .XContentParserUtils .throwUnknownField ;
26- import static org .elasticsearch .common .xcontent .XContentParserUtils .throwUnknownToken ;
27- import static org .elasticsearch .xpack .inference .external .response .XContentUtils .moveToFirstToken ;
28- import static org .elasticsearch .xpack .inference .external .response .XContentUtils .positionParserAtTokenAfterField ;
26+ import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
27+ import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
2928
3029public class JinaAIRerankResponseEntity {
3130
32- private static final Logger logger = LogManager .getLogger (JinaAIRerankResponseEntity .class );
33-
3431 /**
3532 * Parses the JinaAI ranked response.
3633 *
@@ -44,6 +41,7 @@ public class JinaAIRerankResponseEntity {
4441 * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."]
4542 * <p>
4643 * The response will look like (without whitespace):
44+ * <pre>
4745 * {
4846 * "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
4947 * "results": [
@@ -71,86 +69,114 @@ public class JinaAIRerankResponseEntity {
7169 * ],
7270 * "usage": {"total_tokens": 15}
7371 * }
72+ * </pre>
73+ *
74+ * Or like this where documents is a string:
75+ * <pre>
76+ * {
77+ * "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
78+ * "results": [
79+ * {
80+ * "document": "Washington, D.C. is the capital of the United States. It is a federal district.",
81+ * "index": 2,
82+ * "relevance_score": 0.98005307
83+ * },
84+ * {
85+ * "document": "abc",
86+ * "index": 3,
87+ * "relevance_score": 0.27904198
88+ * },
89+ * {
90+ * "document": "Carson City is the capital city of the American state of Nevada.",
91+ * "index": 0,
92+ * "relevance_score": 0.10194652
93+ * }
94+ * ],
95+ * "usage": {
96+ * "total_tokens": 15
97+ * }
98+ * }
99+ * </pre>
100+ *
101+ * This parsing logic handles both cases.
74102 *
75103 * @param response the http response from JinaAI
76104 * @return the parsed response
77105 * @throws IOException if there is an error parsing the response
78106 */
79107 public static InferenceServiceResults fromResponse (HttpResult response ) throws IOException {
80- var parserConfig = XContentParserConfiguration .EMPTY .withDeprecationHandler (LoggingDeprecationHandler .INSTANCE );
81-
82- try (XContentParser jsonParser = XContentFactory .xContent (XContentType .JSON ).createParser (parserConfig , response .body ())) {
83- moveToFirstToken (jsonParser );
84-
85- XContentParser .Token token = jsonParser .currentToken ();
86- ensureExpectedToken (XContentParser .Token .START_OBJECT , token , jsonParser );
108+ try (var p = XContentFactory .xContent (XContentType .JSON ).createParser (XContentParserConfiguration .EMPTY , response .body ())) {
109+ return Response .PARSER .apply (p , null ).toRankedDocsResults ();
110+ }
111+ }
87112
88- positionParserAtTokenAfterField (jsonParser , "results" , FAILED_TO_FIND_FIELD_TEMPLATE );
113+ private record Response (List <ResultItem > results ) {
114+ @ SuppressWarnings ("unchecked" )
115+ public static final ConstructingObjectParser <Response , Void > PARSER = new ConstructingObjectParser <>(
116+ Response .class .getSimpleName (),
117+ true ,
118+ args -> new Response ((List <ResultItem >) args [0 ])
119+ );
89120
90- token = jsonParser .currentToken ();
91- if (token == XContentParser .Token .START_ARRAY ) {
92- return new RankedDocsResults (parseList (jsonParser , JinaAIRerankResponseEntity ::parseRankedDocObject ));
93- } else {
94- throwUnknownToken (token , jsonParser );
95- }
121+ static {
122+ PARSER .declareObjectArray (constructorArg (), ResultItem .PARSER ::apply , new ParseField ("results" ));
123+ }
96124
97- // This should never be reached. The above code should either return successfully or hit the throwUnknownToken
98- // or throw a parsing exception
99- throw new IllegalStateException ("Reached an invalid state while parsing the JinaAI response" );
125+ public RankedDocsResults toRankedDocsResults () {
126+ List <RankedDocsResults .RankedDoc > rankedDocs = results .stream ()
127+ .map (
128+ item -> new RankedDocsResults .RankedDoc (
129+ item .index (),
130+ item .relevanceScore (),
131+ item .document () != null ? item .document ().text () : null
132+ )
133+ )
134+ .toList ();
135+ return new RankedDocsResults (rankedDocs );
100136 }
101137 }
102138
103- private static RankedDocsResults .RankedDoc parseRankedDocObject (XContentParser parser ) throws IOException {
104- ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .currentToken (), parser );
105- int index = -1 ;
106- float relevanceScore = -1 ;
107- String documentText = null ;
108- parser .nextToken ();
109- while (parser .currentToken () != XContentParser .Token .END_OBJECT ) {
110- if (parser .currentToken () == XContentParser .Token .FIELD_NAME ) {
111- switch (parser .currentName ()) {
112- case "index" :
113- parser .nextToken (); // move to VALUE_NUMBER
114- index = parser .intValue ();
115- parser .nextToken (); // move to next FIELD_NAME or END_OBJECT
116- break ;
117- case "relevance_score" :
118- parser .nextToken (); // move to VALUE_NUMBER
119- relevanceScore = parser .floatValue ();
120- parser .nextToken (); // move to next FIELD_NAME or END_OBJECT
121- break ;
122- case "document" :
123- parser .nextToken (); // move to START_OBJECT; document text is wrapped in an object
124- ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .currentToken (), parser );
125- do {
126- if (parser .currentToken () == XContentParser .Token .FIELD_NAME && parser .currentName ().equals ("text" )) {
127- parser .nextToken (); // move to VALUE_STRING
128- documentText = parser .text ();
129- }
130- } while (parser .nextToken () != XContentParser .Token .END_OBJECT );
131- parser .nextToken ();// move past END_OBJECT
132- // parser should now be at the next FIELD_NAME or END_OBJECT
133- break ;
134- default :
135- throwUnknownField (parser .currentName (), parser );
136- }
137- } else {
138- parser .nextToken ();
139- }
139+ private record ResultItem (int index , float relevanceScore , @ Nullable Document document ) {
140+ public static final ConstructingObjectParser <ResultItem , Void > PARSER = new ConstructingObjectParser <>(
141+ ResultItem .class .getSimpleName (),
142+ true ,
143+ args -> new ResultItem ((Integer ) args [0 ], (Float ) args [1 ], (Document ) args [2 ])
144+ );
145+
146+ static {
147+ PARSER .declareInt (constructorArg (), new ParseField ("index" ));
148+ PARSER .declareFloat (constructorArg (), new ParseField ("relevance_score" ));
149+ PARSER .declareField (
150+ optionalConstructorArg (),
151+ (p , c ) -> parseDocument (p ),
152+ new ParseField ("document" ),
153+ ObjectParser .ValueType .OBJECT_OR_STRING
154+ );
140155 }
156+ }
141157
142- if (index == -1 ) {
143- logger .warn ("Failed to find required field [index] in JinaAI rerank response" );
144- }
145- if (relevanceScore == -1 ) {
146- logger .warn ("Failed to find required field [relevance_score] in JinaAI rerank response" );
158+ private record Document (String text ) {}
159+
160+ private static Document parseDocument (XContentParser parser ) throws IOException {
161+ var token = parser .currentToken ();
162+ if (token == XContentParser .Token .START_OBJECT ) {
163+ return new Document (DocumentObject .PARSER .apply (parser , null ).text ());
164+ } else if (token == XContentParser .Token .VALUE_STRING ) {
165+ return new Document (parser .text ());
147166 }
148- // documentText may or may not be present depending on the request parameter
149167
150- return new RankedDocsResults . RankedDoc ( index , relevanceScore , documentText );
168+ throw new XContentParseException ( parser . getTokenLocation (), "Expected an object or string for document field, but got: " + token );
151169 }
152170
153- private JinaAIRerankResponseEntity () {}
171+ private record DocumentObject (String text ) {
172+ public static final ConstructingObjectParser <DocumentObject , Void > PARSER = new ConstructingObjectParser <>(
173+ DocumentObject .class .getSimpleName (),
174+ true ,
175+ args -> new DocumentObject ((String ) args [0 ])
176+ );
154177
155- static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response" ;
178+ static {
179+ PARSER .declareString (constructorArg (), new ParseField ("text" ));
180+ }
181+ }
156182}
0 commit comments