-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Adjust jinaai rerank response parser to handle document field as string or object #136751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
08504a8
53e3bb8
de81a70
c19114f
f5d2196
91e981e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pr: 136751 | ||
summary: Adjust jinaai rerank response parser to handle document field as string or | ||
object | ||
area: Machine Learning | ||
type: bug | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,30 +7,27 @@ | |
|
||
package org.elasticsearch.xpack.inference.services.jinaai.response; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.xcontent.ObjectParser; | ||
import org.elasticsearch.xcontent.ParseField; | ||
import org.elasticsearch.xcontent.XContentFactory; | ||
import org.elasticsearch.xcontent.XContentParseException; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
import org.elasticsearch.xcontent.XContentType; | ||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; | ||
import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; | ||
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; | ||
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; | ||
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; | ||
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; | ||
|
||
public class JinaAIRerankResponseEntity { | ||
|
||
private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class); | ||
|
||
/** | ||
* Parses the JinaAI ranked response. | ||
* | ||
|
@@ -44,6 +41,7 @@ public class JinaAIRerankResponseEntity { | |
* "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] | ||
* <p> | ||
* The response will look like (without whitespace): | ||
* <pre> | ||
* { | ||
* "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703", | ||
* "results": [ | ||
|
@@ -71,86 +69,114 @@ public class JinaAIRerankResponseEntity { | |
* ], | ||
* "usage": {"total_tokens": 15} | ||
* } | ||
* </pre> | ||
* | ||
 * Or like this, where the document field is a string: | ||
* <pre> | ||
* { | ||
* "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703", | ||
* "results": [ | ||
* { | ||
* "document": "Washington, D.C. is the capital of the United States. It is a federal district.", | ||
* "index": 2, | ||
* "relevance_score": 0.98005307 | ||
* }, | ||
* { | ||
* "document": "abc", | ||
* "index": 3, | ||
* "relevance_score": 0.27904198 | ||
* }, | ||
* { | ||
* "document": "Carson City is the capital city of the American state of Nevada.", | ||
* "index": 0, | ||
* "relevance_score": 0.10194652 | ||
* } | ||
* ], | ||
* "usage": { | ||
* "total_tokens": 15 | ||
* } | ||
* } | ||
* </pre> | ||
* | ||
* This parsing logic handles both cases. | ||
* | ||
* @param response the http response from JinaAI | ||
* @return the parsed response | ||
* @throws IOException if there is an error parsing the response | ||
*/ | ||
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { | ||
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); | ||
|
||
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { | ||
moveToFirstToken(jsonParser); | ||
|
||
XContentParser.Token token = jsonParser.currentToken(); | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); | ||
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { | ||
return Response.PARSER.apply(p, null).toRankedDocsResults(); | ||
} | ||
} | ||
|
||
positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); | ||
private record Response(List<ResultItem> results) { | ||
@SuppressWarnings("unchecked") | ||
public static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>( | ||
Response.class.getSimpleName(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, typically we just do the |
||
true, | ||
args -> new Response((List<ResultItem>) args[0]) | ||
); | ||
|
||
token = jsonParser.currentToken(); | ||
if (token == XContentParser.Token.START_ARRAY) { | ||
return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject)); | ||
} else { | ||
throwUnknownToken(token, jsonParser); | ||
} | ||
static { | ||
PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results")); | ||
} | ||
|
||
// This should never be reached. The above code should either return successfully or hit the throwUnknownToken | ||
// or throw a parsing exception | ||
throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response"); | ||
public RankedDocsResults toRankedDocsResults() { | ||
List<RankedDocsResults.RankedDoc> rankedDocs = results.stream() | ||
.map( | ||
item -> new RankedDocsResults.RankedDoc( | ||
item.index(), | ||
item.relevanceScore(), | ||
item.document() != null ? item.document().text() : null | ||
) | ||
) | ||
.toList(); | ||
return new RankedDocsResults(rankedDocs); | ||
} | ||
} | ||
|
||
private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
int index = -1; | ||
float relevanceScore = -1; | ||
String documentText = null; | ||
parser.nextToken(); | ||
while (parser.currentToken() != XContentParser.Token.END_OBJECT) { | ||
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { | ||
switch (parser.currentName()) { | ||
case "index": | ||
parser.nextToken(); // move to VALUE_NUMBER | ||
index = parser.intValue(); | ||
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT | ||
break; | ||
case "relevance_score": | ||
parser.nextToken(); // move to VALUE_NUMBER | ||
relevanceScore = parser.floatValue(); | ||
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT | ||
break; | ||
case "document": | ||
parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
do { | ||
if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { | ||
parser.nextToken(); // move to VALUE_STRING | ||
documentText = parser.text(); | ||
} | ||
} while (parser.nextToken() != XContentParser.Token.END_OBJECT); | ||
parser.nextToken();// move past END_OBJECT | ||
// parser should now be at the next FIELD_NAME or END_OBJECT | ||
break; | ||
default: | ||
throwUnknownField(parser.currentName(), parser); | ||
} | ||
} else { | ||
parser.nextToken(); | ||
} | ||
private record ResultItem(int index, float relevanceScore, @Nullable Document document) { | ||
public static final ConstructingObjectParser<ResultItem, Void> PARSER = new ConstructingObjectParser<>( | ||
ResultItem.class.getSimpleName(), | ||
true, | ||
args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2]) | ||
); | ||
|
||
static { | ||
PARSER.declareInt(constructorArg(), new ParseField("index")); | ||
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score")); | ||
PARSER.declareField( | ||
optionalConstructorArg(), | ||
(p, c) -> parseDocument(p), | ||
new ParseField("document"), | ||
ObjectParser.ValueType.OBJECT_OR_STRING | ||
); | ||
} | ||
} | ||
|
||
if (index == -1) { | ||
logger.warn("Failed to find required field [index] in JinaAI rerank response"); | ||
} | ||
if (relevanceScore == -1) { | ||
logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response"); | ||
private record Document(String text) {} | ||
|
||
private static Document parseDocument(XContentParser parser) throws IOException { | ||
var token = parser.currentToken(); | ||
if (token == XContentParser.Token.START_OBJECT) { | ||
return new Document(DocumentObject.PARSER.apply(parser, null).text()); | ||
} else if (token == XContentParser.Token.VALUE_STRING) { | ||
return new Document(parser.text()); | ||
} | ||
// documentText may or may not be present depending on the request parameter | ||
|
||
return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); | ||
throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token); | ||
} | ||
|
||
private JinaAIRerankResponseEntity() {} | ||
private record DocumentObject(String text) { | ||
public static final ConstructingObjectParser<DocumentObject, Void> PARSER = new ConstructingObjectParser<>( | ||
DocumentObject.class.getSimpleName(), | ||
true, | ||
args -> new DocumentObject((String) args[0]) | ||
); | ||
|
||
static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response"; | ||
static { | ||
PARSER.declareString(constructorArg(), new ParseField("text")); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be worth updating the javadoc for this method to include the new behaviour? If you do decide to update it, putting `<pre>` tags around the JSON parts would be helpful, to make the javadoc more readable in the IDE. Right now, when mousing over the method name to show the javadoc, the JSON is formatted all on one line, which is very difficult to read.