Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/136751.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 136751
summary: Adjust jinaai rerank response parser to handle document field as string or
object
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,27 @@

package org.elasticsearch.xpack.inference.services.jinaai.response;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField;
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class JinaAIRerankResponseEntity {

private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class);

/**
* Parses the JinaAI ranked response.
*
Expand All @@ -44,6 +41,7 @@ public class JinaAIRerankResponseEntity {
* "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."]
* <p>
* The response will look like (without whitespace):
* <pre>
* {
* "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
* "results": [
Expand Down Expand Up @@ -71,86 +69,114 @@ public class JinaAIRerankResponseEntity {
* ],
* "usage": {"total_tokens": 15}
* }
* </pre>
*
* Or like this, where the document field is a plain string instead of an object:
* <pre>
* {
* "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703",
* "results": [
* {
* "document": "Washington, D.C. is the capital of the United States. It is a federal district.",
* "index": 2,
* "relevance_score": 0.98005307
* },
* {
* "document": "abc",
* "index": 3,
* "relevance_score": 0.27904198
* },
* {
* "document": "Carson City is the capital city of the American state of Nevada.",
* "index": 0,
* "relevance_score": 0.10194652
* }
* ],
* "usage": {
* "total_tokens": 15
* }
* }
* </pre>
*
* This parsing logic handles both cases.
*
* @param response the http response from JinaAI
* @return the parsed response
* @throws IOException if there is an error parsing the response
*/
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth updating the javadoc for this method to include the new behaviour? If you do decide to update it, putting <pre> tags around the JSON parts would be helpful, to make the javadoc more readable in the IDE. Right now, when mousing over the method name to show the javadoc, the JSON is formatted all on one line, which is very difficult to read.

var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);

XContentParser.Token token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
return Response.PARSER.apply(p, null).toRankedDocsResults();
}
}

positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE);
private record Response(List<ResultItem> results) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
Response.class.getSimpleName(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using getSimpleName() here will result in any errors during parsing referencing Response which is pretty vague if we need to use it for debugging. Using any of the other "name" methods would be too verbose, but is there some way we could include the name of the parent class? Or is that overkill?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, typically we just use getSimpleName(). I think the error will include the stack trace though, which should lead us to the code that is failing 🤔

true,
args -> new Response((List<ResultItem>) args[0])
);

token = jsonParser.currentToken();
if (token == XContentParser.Token.START_ARRAY) {
return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject));
} else {
throwUnknownToken(token, jsonParser);
}
static {
PARSER.declareObjectArray(constructorArg(), ResultItem.PARSER::apply, new ParseField("results"));
}

// This should never be reached. The above code should either return successfully or hit the throwUnknownToken
// or throw a parsing exception
throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response");
/**
 * Converts the parsed result items into a {@link RankedDocsResults}, keeping the items in the
 * order in which they were parsed.
 *
 * @return the ranked documents; a ranked document's text is {@code null} when its result item
 *         carried no document
 */
public RankedDocsResults toRankedDocsResults() {
    List<RankedDocsResults.RankedDoc> docs = results.stream().map(entry -> {
        // The document is nullable on the result item, so fall back to null text when it is absent.
        var doc = entry.document();
        String text = doc == null ? null : doc.text();
        return new RankedDocsResults.RankedDoc(entry.index(), entry.relevanceScore(), text);
    }).toList();
    return new RankedDocsResults(docs);
}
}

private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
int index = -1;
float relevanceScore = -1;
String documentText = null;
parser.nextToken();
while (parser.currentToken() != XContentParser.Token.END_OBJECT) {
if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
switch (parser.currentName()) {
case "index":
parser.nextToken(); // move to VALUE_NUMBER
index = parser.intValue();
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
break;
case "relevance_score":
parser.nextToken(); // move to VALUE_NUMBER
relevanceScore = parser.floatValue();
parser.nextToken(); // move to next FIELD_NAME or END_OBJECT
break;
case "document":
parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
do {
if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) {
parser.nextToken(); // move to VALUE_STRING
documentText = parser.text();
}
} while (parser.nextToken() != XContentParser.Token.END_OBJECT);
parser.nextToken();// move past END_OBJECT
// parser should now be at the next FIELD_NAME or END_OBJECT
break;
default:
throwUnknownField(parser.currentName(), parser);
}
} else {
parser.nextToken();
}
/**
* One entry of the {@code results} array in the JinaAI rerank response.
*
* @param index the value of the entry's {@code index} field
* @param relevanceScore the value of the entry's {@code relevance_score} field
* @param document the parsed {@code document} field, or {@code null} when the entry has none
*/
private record ResultItem(int index, float relevanceScore, @Nullable Document document) {
public static final ConstructingObjectParser<ResultItem, Void> PARSER = new ConstructingObjectParser<>(
ResultItem.class.getSimpleName(),
// ignoreUnknownFields: fields that are not declared below are skipped instead of failing the parse
true,
// args is positional: each slot corresponds to one declare* call in the static block, in declaration order
args -> new ResultItem((Integer) args[0], (Float) args[1], (Document) args[2])
);

static {
// Keep these declarations in this order; it determines the args[] indices used by the builder above.
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
// The document field is optional and may arrive either as a JSON object or as a plain string;
// parseDocument handles both token types.
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> parseDocument(p),
new ParseField("document"),
ObjectParser.ValueType.OBJECT_OR_STRING
);
}
}

if (index == -1) {
logger.warn("Failed to find required field [index] in JinaAI rerank response");
}
if (relevanceScore == -1) {
logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response");
/**
* Holds the text of a result's {@code document} field. {@code parseDocument} produces this same
* type whether the field was sent as a JSON object or as a plain string.
*
* @param text the document text
*/
private record Document(String text) {}

private static Document parseDocument(XContentParser parser) throws IOException {
var token = parser.currentToken();
if (token == XContentParser.Token.START_OBJECT) {
return new Document(DocumentObject.PARSER.apply(parser, null).text());
} else if (token == XContentParser.Token.VALUE_STRING) {
return new Document(parser.text());
}
// documentText may or may not be present depending on the request parameter

return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText);
throw new XContentParseException(parser.getTokenLocation(), "Expected an object or string for document field, but got: " + token);
}

private JinaAIRerankResponseEntity() {}
private record DocumentObject(String text) {
public static final ConstructingObjectParser<DocumentObject, Void> PARSER = new ConstructingObjectParser<>(
DocumentObject.class.getSimpleName(),
true,
args -> new DocumentObject((String) args[0])
);

static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response";
static {
PARSER.declareString(constructorArg(), new ParseField("text"));
}
}
}
Loading