diff --git a/docs/changelog/121548.yaml b/docs/changelog/121548.yaml new file mode 100644 index 0000000000000..889a3e81c3f8c --- /dev/null +++ b/docs/changelog/121548.yaml @@ -0,0 +1,5 @@ +pr: 121548 +summary: Adding support for specifying embedding type to Jina AI service settings +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 59794e310eb1a..ebac8c9a326b0 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -181,6 +181,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03); public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04); public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05); + public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -207,6 +208,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_DRIVER_NODE_DESCRIPTION = def(9_017_0_00); public static final TransportVersion MULTI_PROJECT = def(9_018_0_00); public static final TransportVersion STORED_SCRIPT_CONTENT_LENGTH = def(9_019_0_00); + public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_020_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java index d99f15a7703ae..63f325c7e59da 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; @@ -30,6 +31,7 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest { private final JinaAIEmbeddingsTaskSettings taskSettings; private final String model; private final String inferenceEntityId; + private final JinaAIEmbeddingType embeddingType; public JinaAIEmbeddingsRequest(List input, JinaAIEmbeddingsModel embeddingsModel) { Objects.requireNonNull(embeddingsModel); @@ -38,6 +40,7 @@ public JinaAIEmbeddingsRequest(List input, JinaAIEmbeddingsModel embeddi this.input = Objects.requireNonNull(input); taskSettings = embeddingsModel.getTaskSettings(); model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); + embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); inferenceEntityId = embeddingsModel.getInferenceEntityId(); } @@ -46,7 +49,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -75,6 +78,10 @@ public boolean[] getTruncationInfo() { return null; } + public JinaAIEmbeddingType getEmbeddingType() { + return embeddingType; + } + public static URI buildDefaultUri() throws URISyntaxException { return new URIBuilder().setScheme("https") .setHost(JinaAIUtils.HOST) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java index d4f98f1eb52ca..8d4c094206e29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import java.io.IOException; @@ -19,9 +20,12 @@ import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.invalidInputTypeMessage; -public record JinaAIEmbeddingsRequestEntity(List input, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable String model) - implements - ToXContentObject { +public record JinaAIEmbeddingsRequestEntity( + List input, + JinaAIEmbeddingsTaskSettings taskSettings, + @Nullable String model, + @Nullable JinaAIEmbeddingType embeddingType +) implements ToXContentObject { private static final String SEARCH_DOCUMENT = "retrieval.passage"; private static final String SEARCH_QUERY = "retrieval.query"; @@ -30,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity(List input, JinaAIEmbeddings private static final String INPUT_FIELD = "input"; private static final String MODEL_FIELD = "model"; public static final String TASK_TYPE_FIELD = "task"; + static final String EMBEDDING_TYPE_FIELD = "embedding_type"; public JinaAIEmbeddingsRequestEntity { Objects.requireNonNull(input); @@ -43,6 +48,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_FIELD, input); builder.field(MODEL_FIELD, model); + if (embeddingType != null) { + builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString()); + } + if (taskSettings.getInputType() != null) { builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java index b1782bb560ac0..91c7010330439 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java @@ -9,28 +9,54 @@ package org.elasticsearch.xpack.inference.external.response.jinaai; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest; import org.elasticsearch.xpack.inference.external.response.XContentUtils; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import java.io.IOException; +import java.util.Arrays; import java.util.List; +import java.util.Map; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.toLowerCase; public class JinaAIEmbeddingsResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response"; + private static final Map> EMBEDDING_PARSERS = Map.of( + toLowerCase(JinaAIEmbeddingType.FLOAT), + JinaAIEmbeddingsResponseEntity::parseFloatDataObject, + toLowerCase(JinaAIEmbeddingType.BIT), + JinaAIEmbeddingsResponseEntity::parseBitDataObject, + toLowerCase(JinaAIEmbeddingType.BINARY), + JinaAIEmbeddingsResponseEntity::parseBitDataObject + ); + private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); + + private static String supportedEmbeddingTypes() { + var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new); + Arrays.sort(validTypes); + return String.join(", ", validTypes); + } + /** * Parses the JinaAI json response. * For a request like: @@ -73,8 +99,21 @@ public class JinaAIEmbeddingsResponseEntity { * * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + // embeddings type is not specified anywhere in the response so grab it from the request + JinaAIEmbeddingsRequest embeddingsRequest = (JinaAIEmbeddingsRequest) request; + var embeddingType = embeddingsRequest.getEmbeddingType().toString(); var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var embeddingValueParser = EMBEDDING_PARSERS.get(embeddingType); + + if (embeddingValueParser == null) { + throw new IllegalStateException( + Strings.format( + "Failed to find a supported embedding type for in the Jina AI embeddings response. Supported types are [%s]", + VALID_EMBEDDING_TYPES_STRING + ) + ); + } try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { moveToFirstToken(jsonParser); @@ -84,26 +123,66 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( - jsonParser, - JinaAIEmbeddingsResponseEntity::parseEmbeddingObject - ); - - return new TextEmbeddingFloatResults(embeddingList); + return embeddingValueParser.apply(jsonParser); } } - private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException { + List embeddingList = parseList( + jsonParser, + JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject + ); + + return new TextEmbeddingFloatResults(embeddingList); + } + + private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + var embeddingValuesList = parseList(parser, XContentUtils::parseFloat); // parse and discard the rest of the object consumeUntilObjectEnd(parser); return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); } + private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException { + List embeddingList = parseList( + jsonParser, + JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject + ); + + return new TextEmbeddingBitResults(embeddingList); + } + + private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + + var embeddingList = parseList(parser, JinaAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry); + // parse and discard the rest of the object + consumeUntilObjectEnd(parser); + + return TextEmbeddingByteResults.Embedding.of(embeddingList); + } + + private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + var parsedByte = parser.shortValue(); + checkByteBounds(parsedByte); + + return (byte) parsedByte; + } + + private static void checkByteBounds(short value) { + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); + } + } + private JinaAIEmbeddingsResponseEntity() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 71917151623f0..6e3c830ae764f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; @@ -294,7 +295,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof JinaAIEmbeddingsModel embeddingsModel) { var serviceSettings = embeddingsModel.getServiceSettings(); var similarityFromModel = serviceSettings.similarity(); - var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel; + var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel; var maxInputTokens = serviceSettings.maxInputTokens(); var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings( @@ -305,7 +306,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { ), similarityToUse, embeddingSize, - maxInputTokens + maxInputTokens, + serviceSettings.getEmbeddingType() ); return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings); @@ -322,7 +324,10 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { * * @return The default similarity. */ - static SimilarityMeasure defaultSimilarity() { + static SimilarityMeasure defaultSimilarity(JinaAIEmbeddingType embeddingType) { + if (embeddingType == JinaAIEmbeddingType.BINARY || embeddingType == JinaAIEmbeddingType.BIT) { + return SimilarityMeasure.L2_NORM; + } return SimilarityMeasure.DOT_PRODUCT; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java new file mode 100644 index 0000000000000..e9c869d59a6fc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.Locale; +import java.util.Map; + +/** + * Defines the type of embedding that the Jina AI API should return for a request. + * + */ +public enum JinaAIEmbeddingType { + /** + * Use this when you want to get back the default float embeddings. + */ + FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT), + /** + * Use this when you want to get back binary embeddings. + */ + BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT), + /** + * This is a synonym for BIT + */ + BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT); + + private static final class RequestConstants { + private static final String FLOAT = "float"; + private static final String BIT = "binary"; + } + + private static final Map ELEMENT_TYPE_TO_JINA_AI_EMBEDDING = Map.of( + DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, + DenseVectorFieldMapper.ElementType.BIT, + BIT + ); + static final EnumSet SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf( + ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.keySet() + ); + + private final DenseVectorFieldMapper.ElementType elementType; + private final String requestString; + + JinaAIEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) { + this.elementType = elementType; + this.requestString = requestString; + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public String toRequestString() { + return requestString; + } + + public static String toLowerCase(JinaAIEmbeddingType type) { + return type.toString().toLowerCase(Locale.ROOT); + } + + public static JinaAIEmbeddingType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static JinaAIEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) { + var embedding = ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.get(elementType); + + if (embedding == null) { + var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream() + .map(value -> value.toString().toLowerCase(Locale.ROOT)) + .toArray(String[]::new); + Arrays.sort(validElementTypes); + + throw new IllegalArgumentException( + Strings.format( + "Element type [%s] does not map to a Jina AI embedding value, must be one of [%s]", + elementType, + String.join(", ", validElementTypes) + ) + ); + } + + return embedding; + } + + public DenseVectorFieldMapper.ElementType toElementType() { + return elementType; + } + + /** + * Returns an embedding type that is known based on the transport version provided. If the embedding type enum was not yet + * introduced it will be defaulted FLOAT. + * + * @param embeddingType the value to translate if necessary + * @param version the version that dictates the translation + * @return the embedding type that is known to the version passed in + */ + public static JinaAIEmbeddingType translateToVersion(JinaAIEmbeddingType embeddingType, TransportVersion version) { + if (version.onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED) + || version.isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19)) { + return embeddingType; + } + + return FLOAT; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java index 449da72674be4..09164f9bbc45a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java @@ -23,18 +23,22 @@ import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import java.io.IOException; +import java.util.EnumSet; import java.util.Map; import java.util.Objects; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { public static final String NAME = "jinaai_embeddings_service_settings"; + static final String EMBEDDING_TYPE = "embedding_type"; + public static JinaAIEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context); @@ -42,28 +46,47 @@ public static JinaAIEmbeddingsServiceSettings fromMap(Map map, C Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + JinaAIEmbeddingType embeddingTypes = parseEmbeddingType(map, validationException); + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new JinaAIEmbeddingsServiceSettings(commonServiceSettings, similarity, dims, maxInputTokens); + return new JinaAIEmbeddingsServiceSettings(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes); + } + + static JinaAIEmbeddingType parseEmbeddingType(Map map, ValidationException validationException) { + return Objects.requireNonNullElse( + extractOptionalEnum( + map, + EMBEDDING_TYPE, + ModelConfigurations.SERVICE_SETTINGS, + JinaAIEmbeddingType::fromString, + EnumSet.allOf(JinaAIEmbeddingType.class), + validationException + ), + JinaAIEmbeddingType.FLOAT + ); } private final JinaAIServiceSettings commonSettings; private final SimilarityMeasure similarity; private final Integer dimensions; private final Integer maxInputTokens; + private final JinaAIEmbeddingType embeddingType; public JinaAIEmbeddingsServiceSettings( JinaAIServiceSettings commonSettings, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, - @Nullable Integer maxInputTokens + @Nullable Integer maxInputTokens, + @Nullable JinaAIEmbeddingType embeddingType ) { this.commonSettings = commonSettings; this.similarity = similarity; this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; + this.embeddingType = embeddingType != null ? embeddingType : JinaAIEmbeddingType.FLOAT; } public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { @@ -71,6 +94,11 @@ public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { this.similarity = in.readOptionalEnum(SimilarityMeasure.class); this.dimensions = in.readOptionalVInt(); this.maxInputTokens = in.readOptionalVInt(); + + this.embeddingType = (in.getTransportVersion().onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED) + || in.getTransportVersion().isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19)) + ? Objects.requireNonNullElse(in.readOptionalEnum(JinaAIEmbeddingType.class), JinaAIEmbeddingType.FLOAT) + : JinaAIEmbeddingType.FLOAT; } public JinaAIServiceSettings getCommonSettings() { @@ -96,9 +124,13 @@ public String modelId() { return commonSettings.modelId(); } + public JinaAIEmbeddingType getEmbeddingType() { + return embeddingType; + } + @Override public DenseVectorFieldMapper.ElementType elementType() { - return DenseVectorFieldMapper.ElementType.FLOAT; + return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType(); } @Override @@ -120,6 +152,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + if (embeddingType != null) { + builder.field(EMBEDDING_TYPE, embeddingType); + } + builder.endObject(); return builder; } @@ -127,7 +163,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { commonSettings.toXContentFragmentOfExposedFields(builder, params); - + if (embeddingType != null) { + builder.field(EMBEDDING_TYPE, embeddingType); + } return builder; } @@ -142,6 +180,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); + + if (out.getTransportVersion().onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED) + || out.getTransportVersion().isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19)) { + out.writeOptionalEnum(JinaAIEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion())); + } } @Override @@ -152,11 +195,12 @@ public boolean equals(Object o) { return Objects.equals(commonSettings, that.commonSettings) && Objects.equals(similarity, that.similarity) && Objects.equals(dimensions, that.dimensions) - && Objects.equals(maxInputTokens, that.maxInputTokens); + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(embeddingType, that.embeddingType); } @Override public int hashCode() { - return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens); + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java index 7f3f6e5cdeb82..6c6e3f85ce020 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.hamcrest.MatcherAssert; @@ -23,25 +24,67 @@ public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new JinaAIEmbeddingsRequestEntity(List.of("abc"), new JinaAIEmbeddingsTaskSettings(InputType.INGEST), "model"); + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + "model", + JinaAIEmbeddingType.FLOAT + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, is(""" - {"input":["abc"],"model":"model","task":"retrieval.passage"}""")); + {"input":["abc"],"model":"model","embedding_type":"float","task":"retrieval.passage"}""")); } public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { - var entity = new JinaAIEmbeddingsRequestEntity(List.of("abc"), JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model"); + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "model", + JinaAIEmbeddingType.FLOAT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","embedding_type":"float"}""")); + } + + public void testXContent_EmbeddingTypesBit() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "model", + JinaAIEmbeddingType.BIT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","embedding_type":"binary"}""")); + } + + public void testXContent_EmbeddingTypesBinary() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity( + List.of("abc"), + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + "model", + JinaAIEmbeddingType.BINARY + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, is(""" - {"input":["abc"],"model":"model"}""")); + {"input":["abc"],"model":"model","embedding_type":"binary"}""")); } public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java index 05194ceb0de9e..00b76e6dca5ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; @@ -29,7 +30,15 @@ public class JinaAIEmbeddingsRequestTests extends ESTestCase { public void testCreateRequest_UrlDefined() throws IOException { var request = createRequest( List.of("abc"), - JinaAIEmbeddingsModelTests.createModel("url", "secret", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, "model") + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) ); var httpRequest = request.createHttpRequest(); @@ -46,13 +55,21 @@ public void testCreateRequest_UrlDefined() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model"))); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "embedding_type", "float"))); } public void testCreateRequest_AllOptionsDefined() throws IOException { var request = createRequest( List.of("abc"), - JinaAIEmbeddingsModelTests.createModel("url", "secret", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model") + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) ); var httpRequest = request.createHttpRequest(); @@ -69,13 +86,58 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.passage"))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.passage", "embedding_type", "float")) + ); } public void testCreateRequest_InputTypeSearch() throws IOException { var request = createRequest( List.of("abc"), - JinaAIEmbeddingsModelTests.createModel("url", "secret", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model") + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(JinaAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query", "embedding_type", "float")) + ); + } + + public void testCreateRequest_EmbeddingTypeBit() throws IOException { + var request = createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), + null, + null, + "model", + JinaAIEmbeddingType.BIT + ) ); var httpRequest = request.createHttpRequest(); @@ -92,7 +154,44 @@ public void testCreateRequest_InputTypeSearch() throws IOException { ); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query"))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query", "embedding_type", "binary")) + ); + } + + public void testCreateRequest_EmbeddingTypeBinary() throws IOException { + var request = createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), + null, + null, + "model", + JinaAIEmbeddingType.BINARY + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(JinaAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query", "embedding_type", "binary")) + ); } public static JinaAIEmbeddingsRequest createRequest(List input, JinaAIEmbeddingsModel model) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java index a5f7c4eadae28..eaaa69df0cbf7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java @@ -9,15 +9,22 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequestTests; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -44,13 +51,25 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } """; - TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); + assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); assertThat( - parsedResults.embeddings(), + ((TextEmbeddingFloatResults) parsedResults).embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) ); } @@ -85,13 +104,25 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException } """; - TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); + assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class)); assertThat( - parsedResults.embeddings(), + ((TextEmbeddingFloatResults) parsedResults).embeddings(), is( List.of( new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), @@ -126,7 +157,18 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() { var thrownException = expectThrows( IllegalStateException.class, () -> JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); @@ -159,7 +201,18 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() { var thrownException = expectThrows( ParsingException.class, () -> JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); @@ -195,7 +248,18 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { var thrownException = expectThrows( IllegalStateException.class, () -> JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); @@ -227,7 +291,18 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { var thrownException = expectThrows( ParsingException.class, () -> JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); @@ -238,7 +313,7 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAString() { ); } - public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException { + public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOException { String responseJson = """ { "object": "list", @@ -247,7 +322,11 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio "object": "embedding", "index": 0, "embedding": [ - 1 + -55, + 74, + 101, + 67, + 83 ] } ], @@ -259,15 +338,29 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio } """; - TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.BINARY + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F })))); + assertThat( + ((TextEmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + ); } - public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { + public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOException { String responseJson = """ { "object": "list", @@ -276,7 +369,11 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti "object": "embedding", "index": 0, "embedding": [ - 40294967295 + -55, + 74, + 101, + 67, + 83 ] } ], @@ -288,12 +385,26 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti } """; - TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.BIT + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat( + ((TextEmbeddingBitResults) parsedResults).embeddings(), + is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) + ); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { @@ -320,7 +431,18 @@ public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { var thrownException = expectThrows( ParsingException.class, () -> JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.BINARY + ) + ), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) ); @@ -372,8 +494,19 @@ public void testFieldsInDifferentOrderServer() throws IOException { } }"""; - TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( - mock(Request.class), + TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse( + JinaAIEmbeddingsRequestTests.createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel( + "url", + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ) + ), new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 3387b8b73978f..5d2ab9e6d2f57 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -42,6 +42,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettingsTests; @@ -112,6 +113,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModel() throws IOExce var embeddingsModel = (JinaAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT)); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, e -> fail("Model parsing should have succeeded " + e.getMessage())); @@ -120,7 +122,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModel() throws IOExce "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), getSecretSettingsMap("secret") ), @@ -138,6 +140,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSett var embeddingsModel = (JinaAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT)); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); @@ -148,7 +151,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSett "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") @@ -167,6 +170,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSett var embeddingsModel = (JinaAIEmbeddingsModel) model; MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.BIT)); MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); @@ -177,7 +181,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSett "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.BIT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), getSecretSettingsMap("secret") ), @@ -204,7 +208,7 @@ public void testParseRequestConfig_OptionalTaskSettings() throws IOException { "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), getSecretSettingsMap("secret") ), modelListener @@ -224,7 +228,7 @@ public void testParseRequestConfig_ThrowsUnsupportedTaskType() throws IOExceptio "id", TaskType.SPARSE_EMBEDDING, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ), @@ -243,7 +247,7 @@ private static ActionListener getModelListenerForException(Class excep public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ); @@ -259,7 +263,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { try (var service = createJinaAIService()) { - var serviceSettings = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + var serviceSettings = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT); serviceSettings.put("extra_key", "value"); var config = getRequestConfigMap( @@ -282,7 +286,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() taskSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), taskSettingsMap, getSecretSettingsMap("secret") ); @@ -302,7 +306,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap secretSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), secretSettingsMap ); @@ -330,7 +334,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWithoutUrl() thr "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ), @@ -343,7 +347,7 @@ public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWithoutUrl() thr public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModel() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -368,7 +372,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModel() public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") @@ -395,7 +399,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhe public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -421,7 +425,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhe public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "oldmodel"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "oldmodel", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ); @@ -446,7 +450,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWithoutUrl() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), getSecretSettingsMap("secret") ); @@ -471,7 +475,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWit public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH), getSecretSettingsMap("secret") ); @@ -500,7 +504,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), secretSettingsMap ); @@ -525,7 +529,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -550,7 +554,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createJinaAIService()) { - var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( @@ -582,7 +586,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), taskSettingsMap, getSecretSettingsMap("secret") ); @@ -607,7 +611,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModel() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); @@ -626,7 +630,7 @@ public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModel() throws IOEx public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), createRandomChunkingSettingsMap() ); @@ -647,7 +651,7 @@ public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSe public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); @@ -667,7 +671,7 @@ public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSe public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model_old"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model_old", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() ); @@ -686,7 +690,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWithoutUrl() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) ); @@ -705,7 +709,7 @@ public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWithoutUrl() t public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createJinaAIService()) { var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() ); persistedConfig.config().put("extra_key", "value"); @@ -724,7 +728,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createJinaAIService()) { - var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( @@ -750,7 +754,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT), taskSettingsMap ); @@ -834,7 +838,8 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException { JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 1, - "jina-clip-v2" + "jina-clip-v2", + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); @@ -850,7 +855,8 @@ public void testCheckModelConfig_UpdatesDimensions() throws IOException { JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, 10, 2, - "jina-clip-v2" + "jina-clip-v2", + JinaAIEmbeddingType.FLOAT ) ) ); @@ -891,7 +897,8 @@ public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() th 10, 1, "jina-clip-v2", - null + null, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); @@ -908,7 +915,8 @@ public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() th 10, 2, "jina-clip-v2", - SimilarityMeasure.DOT_PRODUCT + SimilarityMeasure.DOT_PRODUCT, + JinaAIEmbeddingType.FLOAT ) ) ); @@ -949,7 +957,8 @@ public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosi 10, 1, "jina-clip-v2", - SimilarityMeasure.COSINE + SimilarityMeasure.COSINE, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); @@ -966,7 +975,8 @@ public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosi 10, 2, "jina-clip-v2", - SimilarityMeasure.COSINE + SimilarityMeasure.COSINE, + JinaAIEmbeddingType.FLOAT ) ) ); @@ -986,6 +996,7 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { var embeddingSize = randomNonNegativeInt(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); var model = JinaAIEmbeddingsModelTests.createModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -993,12 +1004,15 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si randomNonNegativeInt(), randomNonNegativeInt(), randomAlphaOfLength(10), - similarityMeasure + similarityMeasure, + embeddingType ); Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); - SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? JinaAIService.defaultSimilarity() : similarityMeasure; + SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null + ? JinaAIService.defaultSimilarity(embeddingType) + : similarityMeasure; assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity()); assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); } @@ -1023,7 +1037,8 @@ public void testInfer_Embedding_UnauthorisedResponse() throws IOException { 1024, 1024, "model", - null + null, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1110,7 +1125,8 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { 1024, 1024, "jina-clip-v2", - null + null, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1137,7 +1153,10 @@ public void testInfer_Embedding_Get_Response_Ingest() throws IOException { MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.passage"))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.passage", "embedding_type", "float")) + ); } } @@ -1175,7 +1194,8 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { 1024, 1024, "jina-clip-v2", - null + null, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1202,7 +1222,10 @@ public void testInfer_Embedding_Get_Response_Search() throws IOException { MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.query"))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.query", "embedding_type", "float")) + ); } } @@ -1224,7 +1247,8 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { 1024, 1024, "jina-clip-v2", - null + null, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1251,7 +1275,10 @@ public void testInfer_Embedding_Get_Response_clustering() throws IOException { MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "separation"))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "separation", "embedding_type", "float")) + ); } } @@ -1289,7 +1316,8 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException 1024, 1024, "jina-clip-v2", - null + null, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); @@ -1307,7 +1335,7 @@ public void testInfer_Embedding_Get_Response_NullInputType() throws IOException MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2"))); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "embedding_type", "float"))); } } @@ -1689,7 +1717,8 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings 1024, 1024, "jina-clip-v2", - null + null, + JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1715,7 +1744,7 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2"))); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "embedding_type", "float"))); } } @@ -1727,7 +1756,8 @@ public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws createRandomChunkingSettings(), 1024, 1024, - "jina-clip-v2" + "jina-clip-v2", + JinaAIEmbeddingType.FLOAT ); test_Embedding_ChunkedInfer_BatchesCalls(model); @@ -1741,7 +1771,8 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept null, 1024, 1024, - "jina-clip-v2" + "jina-clip-v2", + JinaAIEmbeddingType.FLOAT ); test_Embedding_ChunkedInfer_BatchesCalls(model); @@ -1831,12 +1862,20 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("foo", "bar"), "model", "jina-clip-v2"))); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("foo", "bar"), "model", "jina-clip-v2", "embedding_type", "float")) + ); } } - public void testDefaultSimilarity() { - assertEquals(SimilarityMeasure.DOT_PRODUCT, JinaAIService.defaultSimilarity()); + public void testDefaultSimilarity_BinaryEmbedding() { + assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BINARY)); + assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BIT)); + } + + public void testDefaultSimilarity_NotBinaryEmbedding() { + assertEquals(SimilarityMeasure.DOT_PRODUCT, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.FLOAT)); } @SuppressWarnings("checkstyle:LineLength") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java index 58455bb1f54ea..a0044c40e6f1f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java @@ -25,69 +25,171 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase { public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() { - var model = createModel("url", "api_key", null, null, "model"); + var model = createModel("url", "api_key", null, null, "model", JinaAIEmbeddingType.FLOAT); var overriddenModel = JinaAIEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED); MatcherAssert.assertThat(overriddenModel, is(model)); } public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() { - var model = createModel("url", "api_key", null, null, "model"); + var model = createModel("url", "api_key", null, null, "model", JinaAIEmbeddingType.FLOAT); var overriddenModel = JinaAIEmbeddingsModel.of(model, null, InputType.UNSPECIFIED); MatcherAssert.assertThat(overriddenModel, is(model)); } public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { - var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var model = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings((InputType) null), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST); - var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var expectedModel = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() { - var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var model = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH); - var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + var expectedModel = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() { - var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var model = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings((InputType) null), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH); - var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + var expectedModel = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() { - var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var model = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED); - var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + var expectedModel = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() { - var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var model = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings((InputType) null), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); - var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + var expectedModel = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings((InputType) null), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() { - var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var model = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); - var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + var expectedModel = createModel( + "url", + "api_key", + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + null, + null, + "model", + JinaAIEmbeddingType.FLOAT + ); MatcherAssert.assertThat(overriddenModel, is(expectedModel)); } - public static JinaAIEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable String model) { - return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model); + public static JinaAIEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable Integer tokenLimit, + @Nullable String model, + @Nullable JinaAIEmbeddingType embeddingType + ) { + return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model, embeddingType); } public static JinaAIEmbeddingsModel createModel( @@ -95,9 +197,10 @@ public static JinaAIEmbeddingsModel createModel( String apiKey, @Nullable Integer tokenLimit, @Nullable Integer dimensions, - String model + String model, + @Nullable JinaAIEmbeddingType embeddingType ) { - return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model); + return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model, embeddingType); } public static JinaAIEmbeddingsModel createModel( @@ -107,7 +210,8 @@ public static JinaAIEmbeddingsModel createModel( ChunkingSettings chunkingSettings, @Nullable Integer tokenLimit, @Nullable Integer dimensions, - String model + String model, + @Nullable JinaAIEmbeddingType embeddingType ) { return new JinaAIEmbeddingsModel( "id", @@ -116,7 +220,8 @@ public static JinaAIEmbeddingsModel createModel( new JinaAIServiceSettings(url, model, null), SimilarityMeasure.DOT_PRODUCT, dimensions, - tokenLimit + tokenLimit, + embeddingType ), taskSettings, chunkingSettings, @@ -130,7 +235,8 @@ public static JinaAIEmbeddingsModel createModel( JinaAIEmbeddingsTaskSettings taskSettings, @Nullable Integer tokenLimit, @Nullable Integer dimensions, - String model + String model, + @Nullable JinaAIEmbeddingType embeddingType ) { return new JinaAIEmbeddingsModel( "id", @@ -139,7 +245,8 @@ public static JinaAIEmbeddingsModel createModel( new JinaAIServiceSettings(url, model, null), SimilarityMeasure.DOT_PRODUCT, dimensions, - tokenLimit + tokenLimit, + embeddingType ), taskSettings, null, @@ -154,12 +261,19 @@ public static JinaAIEmbeddingsModel createModel( @Nullable Integer tokenLimit, @Nullable Integer dimensions, String model, - @Nullable SimilarityMeasure similarityMeasure + @Nullable SimilarityMeasure similarityMeasure, + @Nullable JinaAIEmbeddingType embeddingType ) { return new JinaAIEmbeddingsModel( "id", "service", - new JinaAIEmbeddingsServiceSettings(new JinaAIServiceSettings(url, model, null), similarityMeasure, dimensions, tokenLimit), + new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings(url, model, null), + similarityMeasure, + dimensions, + tokenLimit, + embeddingType + ), taskSettings, null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java index 6847d249a57a0..a1bcc6d9636c0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java @@ -7,16 +7,18 @@ package org.elasticsearch.xpack.inference.services.jinaai.embeddings; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -35,7 +37,7 @@ import static org.hamcrest.Matchers.is; -public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { +public class JinaAIEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase { public static JinaAIEmbeddingsServiceSettings createRandom() { SimilarityMeasure similarityMeasure = null; Integer dims = null; @@ -44,8 +46,9 @@ public static JinaAIEmbeddingsServiceSettings createRandom() { Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); var commonSettings = JinaAIServiceSettingsTests.createRandom(); + var embeddingType = randomFrom(JinaAIEmbeddingType.values()); - return new JinaAIEmbeddingsServiceSettings(commonSettings, similarityMeasure, dims, maxInputTokens); + return new JinaAIEmbeddingsServiceSettings(commonSettings, similarityMeasure, dims, maxInputTokens, embeddingType); } public void testFromMap() { @@ -79,7 +82,8 @@ public void testFromMap() { new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null), SimilarityMeasure.DOT_PRODUCT, dims, - maxInputTokens + maxInputTokens, + JinaAIEmbeddingType.FLOAT ) ) ); @@ -116,7 +120,48 @@ public void testFromMap_WithModelId() { new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null), SimilarityMeasure.DOT_PRODUCT, dims, - maxInputTokens + maxInputTokens, + JinaAIEmbeddingType.FLOAT + ) + ) + ); + } + + public void testFromMap_WithEmbeddingType() { + var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + JinaAIServiceSettings.MODEL_ID, + model, + JinaAIEmbeddingsServiceSettings.EMBEDDING_TYPE, + JinaAIEmbeddingType.BIT.toString() + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + JinaAIEmbeddingType.BIT ) ) ); @@ -146,7 +191,8 @@ public void testToXContent_WritesAllValues() throws IOException { new JinaAIServiceSettings("url", "model", new RateLimitSettings(3)), SimilarityMeasure.COSINE, 5, - 10 + 10, + JinaAIEmbeddingType.FLOAT ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -154,7 +200,8 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" {"url":"url","model_id":"model",""" + """ - "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10}""")); + "rate_limit":{"requests_per_minute":3},""" + """ + "similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}""")); } @Override @@ -172,6 +219,23 @@ protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsService return randomValueOtherThan(instance, JinaAIEmbeddingsServiceSettingsTests::createRandom); } + @Override + protected JinaAIEmbeddingsServiceSettings mutateInstanceForVersion(JinaAIEmbeddingsServiceSettings instance, TransportVersion version) { + if (version.onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED) + || version.isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19)) { + return instance; + } + + // default to null embedding type if node is on a version before embedding type was introduced + return new JinaAIEmbeddingsServiceSettings( + instance.getCommonSettings(), + instance.similarity(), + instance.dimensions(), + instance.maxInputTokens(), + null + ); + } + @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { List entries = new ArrayList<>(); @@ -180,8 +244,17 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(entries); } - public static Map getServiceSettingsMap(@Nullable String url, String model) { + public static Map getServiceSettingsMap( + @Nullable String url, + String model, + @Nullable JinaAIEmbeddingType embeddingType + ) { var map = new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(url, model)); + + if (embeddingType != null) { + map.put(JinaAIEmbeddingsServiceSettings.EMBEDDING_TYPE, embeddingType.toString()); + } + return map; } }