Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/136751.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 136751
summary: Adjust jinaai rerank response parser to handle document field as string or
object
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,27 @@

package org.elasticsearch.xpack.inference.services.jinaai.response;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField;
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class JinaAIRerankResponseEntity {

private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class);

/**
* Parses the JinaAI ranked response.
*
Expand Down Expand Up @@ -77,80 +74,78 @@ public class JinaAIRerankResponseEntity {
* @throws IOException if there is an error parsing the response
*/
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth updating the javadoc for this method to include the new behaviour? If you do decide to update it, putting <pre> tags around the JSON parts would be helpful, to make the javadoc more readable in the IDE. Right now, when mousing over the method name to show the javadoc, the JSON is formatted all on one line, which is very difficult to read.

var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
return Response.PARSER.apply(p, null).toRankedDocsResults();
}
}

positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE);
private record Response(List<ResultItem> results) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
Response.class.getSimpleName(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using getSimpleName() here will result in any errors during parsing referencing Response which is pretty vague if we need to use it for debugging. Using any of the other "name" methods would be too verbose, but is there some way we could include the name of the parent class? Or is that overkill?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, typically we just use getSimpleName(). I think the error will include the stack trace, though, which should lead us to the code that is failing 🤔

true,
args -> new Response((List<ResultItem>) args[0])
);

token = jsonParser.currentToken();
if (token == XContentParser.Token.START_ARRAY) {
return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject));
} else {
throwUnknownToken(token, jsonParser);
}
static {
PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results"));
}

// This should never be reached. The above code should either return successfully or hit the throwUnknownToken
// or throw a parsing exception
throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response");
/**
 * Converts the parsed {@code results} into a {@link RankedDocsResults}, mapping each item to one ranked doc.
 *
 * @return the ranked docs built from this response
 */
public RankedDocsResults toRankedDocsResults() {
    return new RankedDocsResults(results.stream().map(Response::toRankedDoc).toList());
}

/**
 * Maps a single parsed result item to a ranked doc. The text is {@code null} when the item carried no document.
 */
private static RankedDocsResults.RankedDoc toRankedDoc(ResultItem item) {
    Document document = item.document();
    String text = document == null ? null : document.text();
    return new RankedDocsResults.RankedDoc(item.index(), item.relevance_score(), text);
}
}

private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
int index = -1;
float relevanceScore = -1;
String documentText = null;
parser.nextToken();
while (parser.currentToken() != XContentParser.Token.END_OBJECT) {
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
switch (parser.currentName()) {
case "index":
parser.nextToken(); // move to VALUE_NUMBER
index = parser.intValue();
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
break;
case "relevance_score":
parser.nextToken(); // move to VALUE_NUMBER
relevanceScore = parser.floatValue();
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
break;
case "document":
parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
do {
if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) {
parser.nextToken(); // move to VALUE_STRING
documentText = parser.text();
}
} while (parser.nextToken() != XContentParser.Token.END_OBJECT);
parser.nextToken();// move past END_OBJECT
// parser should now be at the next FIELD_NAME or END_OBJECT
break;
default:
throwUnknownField(parser.currentName(), parser);
}
} else {
parser.nextToken();
}
private record ResultItem(int index, float relevance_score, @Nullable Document document) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, but relevance_score should probably be relevanceScore for code style consistency.

public static final ConstructingObjectParser<ResultItem, Void> PARSER = new ConstructingObjectParser<>(
ResultItem.class.getSimpleName(),
true,
args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2])
);

static {
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> parseDocument(p),
new ParseField("document"),
ObjectParser.ValueType.OBJECT_OR_STRING
);
}
}

if (index == -1) {
logger.warn("Failed to find required field [index] in JinaAI rerank response");
}
if (relevanceScore == -1) {
logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response");
/**
 * Text of the document attached to a rerank result, regardless of whether the response supplied it as a
 * plain string or as an object with a {@code text} field.
 */
private record Document(String text) {}

/**
 * Reads the {@code document} value of a rerank result from the parser's current token.
 * <p>
 * A {@code START_OBJECT} token is parsed with {@code DocumentObject.PARSER} and its {@code text} is used;
 * a {@code VALUE_STRING} token is used as the text directly.
 *
 * @param parser parser expected to be positioned on the value of the {@code document} field
 * @return a {@link Document} wrapping the extracted text
 * @throws IOException if reading from the parser fails
 * @throws XContentParseException if the current token is neither an object start nor a string
 */
private static Document parseDocument(XContentParser parser) throws IOException {
var token = parser.currentToken();
if (token == XContentParser.Token.START_OBJECT) {
return new Document(DocumentObject.PARSER.apply(parser, null).text());
} else if (token == XContentParser.Token.VALUE_STRING) {
return new Document(parser.text());
}
// documentText may or may not be present depending on the request parameter

return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText);
throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token);
}

private JinaAIRerankResponseEntity() {}
/**
 * Object form of the {@code document} field, i.e. a JSON object carrying a {@code text} string.
 * <p>
 * NOTE(review): the {@code true} passed to the parser constructor is presumably the lenient
 * "ignore unknown fields" flag - confirm against {@code ConstructingObjectParser}.
 */
private record DocumentObject(String text) {
public static final ConstructingObjectParser<DocumentObject, Void> PARSER = new ConstructingObjectParser<>(
DocumentObject.class.getSimpleName(),
true,
args -> new DocumentObject((String) args[0])
);

static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response";
static {
// "text" is declared via constructorArg() (as opposed to optionalConstructorArg()).
PARSER.declareString(constructorArg(), new ParseField("text"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
package org.elasticsearch.xpack.inference.services.jinaai.response;

import org.apache.http.HttpResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.hamcrest.MatcherAssert;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand All @@ -26,11 +26,33 @@
public class JinaAIRerankResponseEntityTests extends ESTestCase {

public void testResponseLiteral() throws IOException {
String responseLiteral = """
{
"model": "model",
"results": [
{
"index": 2,
"relevance_score": 0.98005307
},
{
"index": 3,
"relevance_score": 0.27904198
},
{
"index": 0,
"relevance_score": 0.10194652
}
],
"usage": {
"total_tokens": 15
}
}
""";
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8))
);

MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
List<RankedDocsResults.RankedDoc> expected = responseLiteralDocs();
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
Expand Down Expand Up @@ -68,7 +90,7 @@ public void testGeneratedResponse() throws IOException {
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
}
Expand All @@ -82,81 +104,92 @@ private ArrayList<RankedDocsResults.RankedDoc> responseLiteralDocs() {
list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null));
return list;

};

private final String responseLiteral = """
{
"model": "model",
"results": [
{
"index": 2,
"relevance_score": 0.98005307
},
{
"index": 3,
"relevance_score": 0.27904198
},
{
"index": 0,
"relevance_score": 0.10194652
}
],
"usage": {
"total_tokens": 15
}
}
""";
}

// Document texts shared by the JSON response templates and the expected ranked docs below.
// Declared static so they are true class-level constants, matching their UPPER_SNAKE_CASE names.
private static final String WASHINGTON_TEXT = "Washington, D.C..";
private static final String CAPITAL_PUNISHMENT_TEXT =
    "Capital punishment has existed in the United States since before the United States was a country. ";
private static final String CARSON_CITY_TEXT = "Carson City is the capital city of the American state of Nevada.";

// Expected ranked docs for the "with documents" responses. The name and instance scope are kept
// because the test methods reference this field by name.
private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of(
    new RankedDocsResults.RankedDoc(2, 0.98005307F, WASHINGTON_TEXT),
    new RankedDocsResults.RankedDoc(3, 0.27904198F, CAPITAL_PUNISHMENT_TEXT),
    new RankedDocsResults.RankedDoc(0, 0.10194652F, CARSON_CITY_TEXT)
);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could these constants be moved to the top of the class and made static? Also, if you do make them all static, for style consistency, responseLiteralDocsWithText should be in all caps.


public void testResponseLiteralWithDocuments() throws IOException {
String responseLiteralWithDocuments = Strings.format("""
{
"model": "model",
"results": [
{
"document": {
"text": "%s"
},
"index": 2,
"relevance_score": 0.98005307
},
{
"document": {
"text": "%s"
},
"index": 3,
"relevance_score": 0.27904198
},
{
"document": {
"text": "%s"
},
"index": 0,
"relevance_score": 0.10194652
}
],
"usage": {
"total_tokens": 15
}
}
""", WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT);
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
);

MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
}

private final String responseLiteralWithDocuments = """
{
"model": "model",
"results": [
{
"document": {
"text": "Washington, D.C.."
},
"index": 2,
"relevance_score": 0.98005307
},
{
"document": {
"text": "Capital punishment has existed in the United States since beforethe United States was a country. "
public void testResponseLiteralWithDocumentsAsString() throws IOException {
String responseLiteralWithDocuments = Strings.format("""
{
"model": "model",
"results": [
{
"document": "%s",
"index": 2,
"relevance_score": 0.98005307
},
"index": 3,
"relevance_score": 0.27904198
},
{
"document": {
"text": "Carson City is the capital city of the American state of Nevada."
{
"document": "%s",
"index": 3,
"relevance_score": 0.27904198
},
"index": 0,
"relevance_score": 0.10194652
{
"document": "%s",
"index": 0,
"relevance_score": 0.10194652
}
],
"usage": {
"total_tokens": 15
}
],
"usage": {
"total_tokens": 15
}
}
""";
""", WASHINGTON_TEXT, CAPITAL_PUNISHMENT_TEXT, CARSON_CITY_TEXT);
InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
);

private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of(
new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."),
new RankedDocsResults.RankedDoc(
3,
0.27904198F,
"Capital punishment has existed in the United States since beforethe United States was a country. "
),
new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.")
);
assertThat(parsedResults, instanceOf(RankedDocsResults.class));
assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
}

private ArrayList<Integer> linear(int n) {
ArrayList<Integer> list = new ArrayList<>();
Expand Down