-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Adjust jinaai rerank response parser to handle document field as string or object #136751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
08504a8
53e3bb8
de81a70
c19114f
f5d2196
91e981e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pr: 136751 | ||
summary: Adjust jinaai rerank response parser to handle document field as string or | ||
object | ||
area: Machine Learning | ||
type: bug | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,30 +7,27 @@ | |
|
||
package org.elasticsearch.xpack.inference.services.jinaai.response; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.xcontent.ObjectParser; | ||
import org.elasticsearch.xcontent.ParseField; | ||
import org.elasticsearch.xcontent.XContentFactory; | ||
import org.elasticsearch.xcontent.XContentParseException; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
import org.elasticsearch.xcontent.XContentType; | ||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; | ||
import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; | ||
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; | ||
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; | ||
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; | ||
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; | ||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; | ||
|
||
public class JinaAIRerankResponseEntity { | ||
|
||
private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class); | ||
|
||
/** | ||
* Parses the JinaAI ranked response. | ||
* | ||
|
@@ -77,80 +74,78 @@ public class JinaAIRerankResponseEntity { | |
* @throws IOException if there is an error parsing the response | ||
*/ | ||
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { | ||
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); | ||
|
||
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { | ||
moveToFirstToken(jsonParser); | ||
|
||
XContentParser.Token token = jsonParser.currentToken(); | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); | ||
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { | ||
return Response.PARSER.apply(p, null).toRankedDocsResults(); | ||
} | ||
} | ||
|
||
positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); | ||
private record Response(List<ResultItem> results) { | ||
@SuppressWarnings("unchecked") | ||
public static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>( | ||
Response.class.getSimpleName(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, typically we just do the |
||
true, | ||
args -> new Response((List<ResultItem>) args[0]) | ||
); | ||
|
||
token = jsonParser.currentToken(); | ||
if (token == XContentParser.Token.START_ARRAY) { | ||
return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject)); | ||
} else { | ||
throwUnknownToken(token, jsonParser); | ||
} | ||
static { | ||
PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results")); | ||
} | ||
|
||
// This should never be reached. The above code should either return successfully or hit the throwUnknownToken | ||
// or throw a parsing exception | ||
throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response"); | ||
public RankedDocsResults toRankedDocsResults() { | ||
List<RankedDocsResults.RankedDoc> rankedDocs = results.stream() | ||
.map( | ||
item -> new RankedDocsResults.RankedDoc( | ||
item.index(), | ||
item.relevance_score(), | ||
item.document() != null ? item.document().text() : null | ||
) | ||
) | ||
.toList(); | ||
return new RankedDocsResults(rankedDocs); | ||
} | ||
} | ||
|
||
private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
int index = -1; | ||
float relevanceScore = -1; | ||
String documentText = null; | ||
parser.nextToken(); | ||
while (parser.currentToken() != XContentParser.Token.END_OBJECT) { | ||
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { | ||
switch (parser.currentName()) { | ||
case "index": | ||
parser.nextToken(); // move to VALUE_NUMBER | ||
index = parser.intValue(); | ||
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT | ||
break; | ||
case "relevance_score": | ||
parser.nextToken(); // move to VALUE_NUMBER | ||
relevanceScore = parser.floatValue(); | ||
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT | ||
break; | ||
case "document": | ||
parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
do { | ||
if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { | ||
parser.nextToken(); // move to VALUE_STRING | ||
documentText = parser.text(); | ||
} | ||
} while (parser.nextToken() != XContentParser.Token.END_OBJECT); | ||
parser.nextToken();// move past END_OBJECT | ||
// parser should now be at the next FIELD_NAME or END_OBJECT | ||
break; | ||
default: | ||
throwUnknownField(parser.currentName(), parser); | ||
} | ||
} else { | ||
parser.nextToken(); | ||
} | ||
private record ResultItem(int index, float relevance_score, @Nullable Document document) { | ||
|
||
public static final ConstructingObjectParser<ResultItem, Void> PARSER = new ConstructingObjectParser<>( | ||
ResultItem.class.getSimpleName(), | ||
true, | ||
args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2]) | ||
); | ||
|
||
static { | ||
PARSER.declareInt(constructorArg(), new ParseField("index")); | ||
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score")); | ||
PARSER.declareField( | ||
optionalConstructorArg(), | ||
(p, c) -> parseDocument(p), | ||
new ParseField("document"), | ||
ObjectParser.ValueType.OBJECT_OR_STRING | ||
); | ||
} | ||
} | ||
|
||
if (index == -1) { | ||
logger.warn("Failed to find required field [index] in JinaAI rerank response"); | ||
} | ||
if (relevanceScore == -1) { | ||
logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response"); | ||
private record Document(String text) {} | ||
|
||
private static Document parseDocument(XContentParser parser) throws IOException { | ||
var token = parser.currentToken(); | ||
if (token == XContentParser.Token.START_OBJECT) { | ||
return new Document(DocumentObject.PARSER.apply(parser, null).text()); | ||
} else if (token == XContentParser.Token.VALUE_STRING) { | ||
return new Document(parser.text()); | ||
} | ||
// documentText may or may not be present depending on the request parameter | ||
|
||
return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); | ||
throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token); | ||
} | ||
|
||
private JinaAIRerankResponseEntity() {} | ||
private record DocumentObject(String text) { | ||
public static final ConstructingObjectParser<DocumentObject, Void> PARSER = new ConstructingObjectParser<>( | ||
DocumentObject.class.getSimpleName(), | ||
true, | ||
args -> new DocumentObject((String) args[0]) | ||
); | ||
|
||
static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response"; | ||
static { | ||
PARSER.declareString(constructorArg(), new ParseField("text")); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,11 +8,11 @@ | |
package org.elasticsearch.xpack.inference.services.jinaai.response; | ||
|
||
import org.apache.http.HttpResponse; | ||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.test.ESTestCase; | ||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; | ||
import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
import org.hamcrest.MatcherAssert; | ||
|
||
import java.io.IOException; | ||
import java.nio.charset.StandardCharsets; | ||
|
@@ -26,11 +26,33 @@ | |
public class JinaAIRerankResponseEntityTests extends ESTestCase { | ||
|
||
public void testResponseLiteral() throws IOException { | ||
String responseLiteral = """ | ||
{ | ||
"model": "model", | ||
"results": [ | ||
{ | ||
"index": 2, | ||
"relevance_score": 0.98005307 | ||
}, | ||
{ | ||
"index": 3, | ||
"relevance_score": 0.27904198 | ||
}, | ||
{ | ||
"index": 0, | ||
"relevance_score": 0.10194652 | ||
} | ||
], | ||
"usage": { | ||
"total_tokens": 15 | ||
} | ||
} | ||
"""; | ||
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( | ||
new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8)) | ||
); | ||
|
||
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); | ||
assertThat(parsedResults, instanceOf(RankedDocsResults.class)); | ||
List<RankedDocsResults.RankedDoc> expected = responseLiteralDocs(); | ||
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { | ||
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); | ||
|
@@ -68,7 +90,7 @@ public void testGeneratedResponse() throws IOException { | |
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( | ||
new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8)) | ||
); | ||
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); | ||
assertThat(parsedResults, instanceOf(RankedDocsResults.class)); | ||
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { | ||
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); | ||
} | ||
|
@@ -82,81 +104,92 @@ private ArrayList<RankedDocsResults.RankedDoc> responseLiteralDocs() { | |
list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null)); | ||
return list; | ||
|
||
}; | ||
|
||
private final String responseLiteral = """ | ||
{ | ||
"model": "model", | ||
"results": [ | ||
{ | ||
"index": 2, | ||
"relevance_score": 0.98005307 | ||
}, | ||
{ | ||
"index": 3, | ||
"relevance_score": 0.27904198 | ||
}, | ||
{ | ||
"index": 0, | ||
"relevance_score": 0.10194652 | ||
} | ||
], | ||
"usage": { | ||
"total_tokens": 15 | ||
} | ||
} | ||
"""; | ||
} | ||
|
||
private final String WASHINGTON_TEXT = "Washington, D.C.."; | ||
private final String CAPITAL_PUNISHMENT_TEXT = | ||
"Capital punishment has existed in the United States since before the United States was a country. "; | ||
private final String CARSON_CITY_TEXT = "Carson City is the capital city of the American state of Nevada."; | ||
|
||
private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of( | ||
new RankedDocsResults.RankedDoc(2, 0.98005307F, WASHINGTON_TEXT), | ||
new RankedDocsResults.RankedDoc(3, 0.27904198F, CAPITAL_PUNISHMENT_TEXT), | ||
new RankedDocsResults.RankedDoc(0, 0.10194652F, CARSON_CITY_TEXT) | ||
); | ||
|
||
|
||
public void testResponseLiteralWithDocuments() throws IOException { | ||
String responseLiteralWithDocuments = Strings.format(""" | ||
{ | ||
"model": "model", | ||
"results": [ | ||
{ | ||
"document": { | ||
"text": "%s" | ||
}, | ||
"index": 2, | ||
"relevance_score": 0.98005307 | ||
}, | ||
{ | ||
"document": { | ||
"text": "%s" | ||
}, | ||
"index": 3, | ||
"relevance_score": 0.27904198 | ||
}, | ||
{ | ||
"document": { | ||
"text": "%s" | ||
}, | ||
"index": 0, | ||
"relevance_score": 0.10194652 | ||
} | ||
], | ||
"usage": { | ||
"total_tokens": 15 | ||
} | ||
} | ||
""", WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT); | ||
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( | ||
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) | ||
); | ||
|
||
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); | ||
MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); | ||
assertThat(parsedResults, instanceOf(RankedDocsResults.class)); | ||
assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); | ||
} | ||
|
||
private final String responseLiteralWithDocuments = """ | ||
{ | ||
"model": "model", | ||
"results": [ | ||
{ | ||
"document": { | ||
"text": "Washington, D.C.." | ||
}, | ||
"index": 2, | ||
"relevance_score": 0.98005307 | ||
}, | ||
{ | ||
"document": { | ||
"text": "Capital punishment has existed in the United States since beforethe United States was a country. " | ||
public void testResponseLiteralWithDocumentsAsString() throws IOException { | ||
String responseLiteralWithDocuments = Strings.format(""" | ||
{ | ||
"model": "model", | ||
"results": [ | ||
{ | ||
"document": "%s", | ||
"index": 2, | ||
"relevance_score": 0.98005307 | ||
}, | ||
"index": 3, | ||
"relevance_score": 0.27904198 | ||
}, | ||
{ | ||
"document": { | ||
"text": "Carson City is the capital city of the American state of Nevada." | ||
{ | ||
"document": "%s", | ||
"index": 3, | ||
"relevance_score": 0.27904198 | ||
}, | ||
"index": 0, | ||
"relevance_score": 0.10194652 | ||
{ | ||
"document": "%s", | ||
"index": 0, | ||
"relevance_score": 0.10194652 | ||
} | ||
], | ||
"usage": { | ||
"total_tokens": 15 | ||
} | ||
], | ||
"usage": { | ||
"total_tokens": 15 | ||
} | ||
} | ||
"""; | ||
""", WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT); | ||
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( | ||
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) | ||
); | ||
|
||
private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of( | ||
new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."), | ||
new RankedDocsResults.RankedDoc( | ||
3, | ||
0.27904198F, | ||
"Capital punishment has existed in the United States since beforethe United States was a country. " | ||
), | ||
new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.") | ||
); | ||
assertThat(parsedResults, instanceOf(RankedDocsResults.class)); | ||
assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); | ||
} | ||
|
||
private ArrayList<Integer> linear(int n) { | ||
ArrayList<Integer> list = new ArrayList<>(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be worth updating the javadoc for this method to include the new behaviour? If you do decide to update it, putting
<pre>
tags around the JSON parts would be helpful, to make the javadoc more readable in the IDE. Right now, when mousing over the method name to show the javadoc, the JSON is formatted all on one line, which is very difficult to read.