Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/121548.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 121548
summary: Adding support for specifying embedding type to Jina AI service settings
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand All @@ -207,6 +208,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_DRIVER_NODE_DESCRIPTION = def(9_017_0_00);
public static final TransportVersion MULTI_PROJECT = def(9_018_0_00);
public static final TransportVersion STORED_SCRIPT_CONTENT_LENGTH = def(9_019_0_00);
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_020_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;

Expand All @@ -30,6 +31,7 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
private final JinaAIEmbeddingsTaskSettings taskSettings;
private final String model;
private final String inferenceEntityId;
private final JinaAIEmbeddingType embeddingType;

public JinaAIEmbeddingsRequest(List<String> input, JinaAIEmbeddingsModel embeddingsModel) {
Objects.requireNonNull(embeddingsModel);
Expand All @@ -38,6 +40,7 @@ public JinaAIEmbeddingsRequest(List<String> input, JinaAIEmbeddingsModel embeddi
this.input = Objects.requireNonNull(input);
taskSettings = embeddingsModel.getTaskSettings();
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
inferenceEntityId = embeddingsModel.getInferenceEntityId();
}

Expand All @@ -46,7 +49,7 @@ public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(account.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

Expand Down Expand Up @@ -75,6 +78,10 @@ public boolean[] getTruncationInfo() {
return null;
}

public JinaAIEmbeddingType getEmbeddingType() {
return embeddingType;
}

public static URI buildDefaultUri() throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(JinaAIUtils.HOST)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;

import java.io.IOException;
Expand All @@ -19,9 +20,12 @@

import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.invalidInputTypeMessage;

public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable String model)
implements
ToXContentObject {
public record JinaAIEmbeddingsRequestEntity(
List<String> input,
JinaAIEmbeddingsTaskSettings taskSettings,
@Nullable String model,
@Nullable JinaAIEmbeddingType embeddingType
) implements ToXContentObject {

private static final String SEARCH_DOCUMENT = "retrieval.passage";
private static final String SEARCH_QUERY = "retrieval.query";
Expand All @@ -30,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddings
private static final String INPUT_FIELD = "input";
private static final String MODEL_FIELD = "model";
public static final String TASK_TYPE_FIELD = "task";
static final String EMBEDDING_TYPE_FIELD = "embedding_type";

public JinaAIEmbeddingsRequestEntity {
Objects.requireNonNull(input);
Expand All @@ -43,6 +48,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INPUT_FIELD, input);
builder.field(MODEL_FIELD, model);

if (embeddingType != null) {
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString());
}

if (taskSettings.getInputType() != null) {
builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,54 @@

package org.elasticsearch.xpack.inference.external.response.jinaai;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.toLowerCase;

public class JinaAIEmbeddingsResponseEntity {
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response";

private static final Map<String, CheckedFunction<XContentParser, InferenceServiceResults, IOException>> EMBEDDING_PARSERS = Map.of(
toLowerCase(JinaAIEmbeddingType.FLOAT),
JinaAIEmbeddingsResponseEntity::parseFloatDataObject,
toLowerCase(JinaAIEmbeddingType.BIT),
JinaAIEmbeddingsResponseEntity::parseBitDataObject,
toLowerCase(JinaAIEmbeddingType.BINARY),
JinaAIEmbeddingsResponseEntity::parseBitDataObject
);
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();

private static String supportedEmbeddingTypes() {
var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new);
Arrays.sort(validTypes);
return String.join(", ", validTypes);
}

/**
* Parses the JinaAI json response.
* For a request like:
Expand Down Expand Up @@ -73,8 +99,21 @@ public class JinaAIEmbeddingsResponseEntity {
* </code>
* </pre>
*/
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
// embeddings type is not specified anywhere in the response so grab it from the request
JinaAIEmbeddingsRequest embeddingsRequest = (JinaAIEmbeddingsRequest) request;
var embeddingType = embeddingsRequest.getEmbeddingType().toString();
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var embeddingValueParser = EMBEDDING_PARSERS.get(embeddingType);

if (embeddingValueParser == null) {
throw new IllegalStateException(
Strings.format(
"Failed to find a supported embedding type for in the Jina AI embeddings response. Supported types are [%s]",
VALID_EMBEDDING_TYPES_STRING
)
);
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);
Expand All @@ -84,26 +123,66 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult

positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE);

List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
jsonParser,
JinaAIEmbeddingsResponseEntity::parseEmbeddingObject
);

return new TextEmbeddingFloatResults(embeddingList);
return embeddingValueParser.apply(jsonParser);
}
}

private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException {
List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
jsonParser,
JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject
);

return new TextEmbeddingFloatResults(embeddingList);
}

private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);

positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);

List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
var embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);

return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}

private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException {
List<TextEmbeddingByteResults.Embedding> embeddingList = parseList(
jsonParser,
JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject
);

return new TextEmbeddingBitResults(embeddingList);
}

private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);

positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);

var embeddingList = parseList(parser, JinaAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
// parse and discard the rest of the object
consumeUntilObjectEnd(parser);

return TextEmbeddingByteResults.Embedding.of(embeddingList);
}

private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
XContentParser.Token token = parser.currentToken();
ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
var parsedByte = parser.shortValue();
checkByteBounds(parsedByte);

return (byte) parsedByte;
}

private static void checkByteBounds(short value) {
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
}
}

private JinaAIEmbeddingsResponseEntity() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
Expand Down Expand Up @@ -294,7 +295,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof JinaAIEmbeddingsModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();
var similarityFromModel = serviceSettings.similarity();
var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;
var maxInputTokens = serviceSettings.maxInputTokens();

var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings(
Expand All @@ -305,7 +306,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
),
similarityToUse,
embeddingSize,
maxInputTokens
maxInputTokens,
serviceSettings.getEmbeddingType()
);

return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings);
Expand All @@ -322,7 +324,10 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
*
* @return The default similarity.
*/
static SimilarityMeasure defaultSimilarity() {
static SimilarityMeasure defaultSimilarity(JinaAIEmbeddingType embeddingType) {
if (embeddingType == JinaAIEmbeddingType.BINARY || embeddingType == JinaAIEmbeddingType.BIT) {
return SimilarityMeasure.L2_NORM;
}
return SimilarityMeasure.DOT_PRODUCT;
}

Expand Down
Loading
Loading