diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index fd3852d8ecca3..1367a7db1af93 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -334,6 +334,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00); public static final TransportVersion ESQL_SPLIT_ON_BIG_VALUES = def(9_116_0_00); public static final TransportVersion ESQL_LOCAL_RELATION_WITH_NEW_BLOCKS = def(9_117_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = def(9_118_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index a3e9d69654f1f..adbec49328804 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1085,5 +1085,11 @@ public static void validateInputTypeAgainstAllowlist( } } + public static void checkByteBounds(short value) { + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); + } + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java index b23f515055b9d..384acefc50e3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java @@ -46,7 +46,7 @@ public CustomModel( inferenceId, taskType, service, - CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomServiceSettings.fromMap(serviceSettings, context, taskType), CustomTaskSettings.fromMap(taskSettings), CustomSecretSettings.fromMap(secrets) ); @@ -66,7 +66,7 @@ public CustomModel( inferenceId, taskType, service, - CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomServiceSettings.fromMap(serviceSettings, context, taskType), CustomTaskSettings.fromMap(taskSettings), CustomSecretSettings.fromMap(secrets), chunkingSettings diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 4a4166cf65ed3..4e81d37ead3ad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -333,12 +333,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; return new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - similarityToUse, - embeddingSize, - serviceSettings.getMaxInputTokens(), - serviceSettings.elementType() - ), + new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens()), serviceSettings.getUrl(), serviceSettings.getHeaders(), serviceSettings.getQueryParameters(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceEmbeddingType.java new file mode 100644 index 0000000000000..143a0b2dcb6fe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceEmbeddingType.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.util.Locale; + +public enum CustomServiceEmbeddingType { + /** + * Use this when you want to get back the default float embeddings. + */ + FLOAT(DenseVectorFieldMapper.ElementType.FLOAT), + /** + * Use this when you want to get back signed int8 embeddings. + */ + BYTE(DenseVectorFieldMapper.ElementType.BYTE), + /** + * Use this when you want to get back binary embeddings. + */ + BIT(DenseVectorFieldMapper.ElementType.BIT), + /** + * This is a synonym for BIT + */ + BINARY(DenseVectorFieldMapper.ElementType.BIT); + + private final DenseVectorFieldMapper.ElementType elementType; + + CustomServiceEmbeddingType(DenseVectorFieldMapper.ElementType elementType) { + this.elementType = elementType; + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public DenseVectorFieldMapper.ElementType toElementType() { + return elementType; + } + + public static CustomServiceEmbeddingType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 931eb3b798553..8b799e472d512 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -66,12 +66,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10; - public static CustomServiceSettings fromMap( - Map map, - ConfigurationParseContext context, - TaskType taskType, - String inferenceId - ) { + public static CustomServiceSettings fromMap(Map map, ConfigurationParseContext context, TaskType taskType) { ValidationException validationException = new ValidationException(); var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException); @@ -137,22 +132,12 @@ public static CustomServiceSettings fromMap( ); } - public record TextEmbeddingSettings( - @Nullable SimilarityMeasure similarityMeasure, - @Nullable Integer dimensions, - @Nullable Integer maxInputTokens, - @Nullable DenseVectorFieldMapper.ElementType elementType - ) implements ToXContentFragment, Writeable { + public static class TextEmbeddingSettings implements ToXContentFragment, Writeable { // This specifies float for the element type but null for all other settings - public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings( - null, - null, - null, - DenseVectorFieldMapper.ElementType.FLOAT - ); + public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(null, null, null); // This refers to settings that are not related to the text embedding task type (all the settings should be null) - public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); + public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null); public static TextEmbeddingSettings fromMap(Map map, TaskType taskType, ValidationException validationException) { if (taskType != TaskType.TEXT_EMBEDDING) { @@ -162,16 +147,31 @@ public static TextEmbeddingSettings fromMap(Map map, TaskType ta SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT); + return new TextEmbeddingSettings(similarity, dims, maxInputTokens); + } + + private final SimilarityMeasure similarityMeasure; + private final Integer dimensions; + private final Integer maxInputTokens; + + public TextEmbeddingSettings( + @Nullable SimilarityMeasure similarityMeasure, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + this.similarityMeasure = similarityMeasure; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; } public TextEmbeddingSettings(StreamInput in) throws IOException { - this( - in.readOptionalEnum(SimilarityMeasure.class), - in.readOptionalVInt(), - in.readOptionalVInt(), - in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class) - ); + this.similarityMeasure = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + + if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { + in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class); + } } @Override @@ -179,7 +179,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(similarityMeasure); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); - out.writeOptionalEnum(elementType); + + if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { + out.writeOptionalEnum(null); + } } @Override @@ -193,8 +196,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + return builder; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + TextEmbeddingSettings that = (TextEmbeddingSettings) o; + return similarityMeasure == that.similarityMeasure + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens); + } + + @Override + public int hashCode() { + return Objects.hash(similarityMeasure, dimensions, maxInputTokens); + } } private final TextEmbeddingSettings textEmbeddingSettings; @@ -300,7 +318,12 @@ public Integer dimensions() { @Override public DenseVectorFieldMapper.ElementType elementType() { - return textEmbeddingSettings.elementType; + var embeddingType = responseJsonParser.getEmbeddingType(); + if (embeddingType != null) { + return embeddingType.toElementType(); + } + + return null; } public Integer getMaxInputTokens() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java index 99b035ef056c7..ac9a223de17c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java @@ -22,7 +22,9 @@ import java.util.Objects; import java.util.function.BiFunction; -public abstract class BaseCustomResponseParser implements CustomResponseParser { +import static org.elasticsearch.xpack.inference.services.ServiceUtils.checkByteBounds; + +public abstract class BaseCustomResponseParser implements CustomResponseParser { @Override public InferenceServiceResults parse(HttpResult response) throws IOException { @@ -36,7 +38,7 @@ public InferenceServiceResults parse(HttpResult response) throws IOException { } } - protected abstract T transform(Map extractedField); + protected abstract InferenceServiceResults transform(Map extractedField); static List validateList(Object obj, String fieldName) { validateNonNull(obj, fieldName); @@ -97,6 +99,21 @@ static Float toFloat(Object obj, String fieldName) { return toNumber(obj, fieldName).floatValue(); } + static List convertToListOfBits(Object obj, String fieldName) { + return convertToListOfBytes(obj, fieldName); + } + + static List convertToListOfBytes(Object obj, String fieldName) { + return castList(validateList(obj, fieldName), BaseCustomResponseParser::toByte, fieldName); + } + + static Byte toByte(Object obj, String fieldName) { + var shortValue = toNumber(obj, fieldName).shortValue(); + checkByteBounds(shortValue); + + return (byte) shortValue; + } + private static Number toNumber(Object obj, String fieldName) { if (obj instanceof Number == false) { throw new IllegalArgumentException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java index ecd3125e228c9..e0cf5643f3d05 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java @@ -23,7 +23,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class CompletionResponseParser extends BaseCustomResponseParser { +public class CompletionResponseParser extends BaseCustomResponseParser { public static final String NAME = "completion_response_parser"; public static final String COMPLETION_PARSER_RESULT = "completion_result"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java index 3a421307d76a8..01fb7c79d7353 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java @@ -11,9 +11,17 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import java.io.IOException; public interface CustomResponseParser extends ToXContentFragment, NamedWriteable { InferenceServiceResults parse(HttpResult response) throws IOException; + + /** + * Returns the configured embedding type for this response parser. This should be overridden for text embedding parsers. + */ + default CustomServiceEmbeddingType getEmbeddingType() { + return null; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java index 0a4c2c42b8c79..8c182a5f4b8a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java @@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class RerankResponseParser extends BaseCustomResponseParser { +public class RerankResponseParser extends BaseCustomResponseParser { public static final String NAME = "rerank_response_parser"; public static final String RERANK_PARSER_SCORE = "relevance_score"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java index 5e21bb018c1fe..85046eb86c9a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java @@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class SparseEmbeddingResponseParser extends BaseCustomResponseParser { +public class SparseEmbeddingResponseParser extends BaseCustomResponseParser { public static final String NAME = "sparse_embedding_response_parser"; public static final String SPARSE_EMBEDDING_TOKEN_PATH = "token_path"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java index b5b0a191f3c4e..51f7ef29be666 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java @@ -7,34 +7,42 @@ package org.elasticsearch.xpack.inference.services.custom.response; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.common.MapPathExtractor; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import java.io.IOException; import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class TextEmbeddingResponseParser extends BaseCustomResponseParser { +public class TextEmbeddingResponseParser extends BaseCustomResponseParser { public static final String NAME = "text_embedding_response_parser"; public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings"; - - private final String textEmbeddingsPath; + public static final String EMBEDDING_TYPE = "embedding_type"; public static TextEmbeddingResponseParser fromMap( Map responseParserMap, String scope, ValidationException validationException ) { + var jsonParserScope = String.join(".", scope, JSON_PARSER); var path = extractRequiredString( responseParserMap, TEXT_EMBEDDING_PARSER_EMBEDDINGS, @@ -42,45 +50,81 @@ public static TextEmbeddingResponseParser fromMap( validationException ); + var embeddingType = Objects.requireNonNullElse( + extractOptionalEnum( + responseParserMap, + EMBEDDING_TYPE, + jsonParserScope, + CustomServiceEmbeddingType::fromString, + EnumSet.allOf(CustomServiceEmbeddingType.class), + validationException + ), + CustomServiceEmbeddingType.FLOAT + ); + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new TextEmbeddingResponseParser(path); + return new TextEmbeddingResponseParser(path, embeddingType); } - public TextEmbeddingResponseParser(String textEmbeddingsPath) { + private final String textEmbeddingsPath; + private final CustomServiceEmbeddingType embeddingType; + + public TextEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) { this.textEmbeddingsPath = Objects.requireNonNull(textEmbeddingsPath); + this.embeddingType = Objects.requireNonNull(embeddingType); } public TextEmbeddingResponseParser(StreamInput in) throws IOException { this.textEmbeddingsPath = in.readString(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { + this.embeddingType = in.readEnum(CustomServiceEmbeddingType.class); + } else { + this.embeddingType = CustomServiceEmbeddingType.FLOAT; + } } public void writeTo(StreamOutput out) throws IOException { out.writeString(textEmbeddingsPath); + + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { + out.writeEnum(embeddingType); + } } public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(JSON_PARSER); { builder.field(TEXT_EMBEDDING_PARSER_EMBEDDINGS, textEmbeddingsPath); + builder.field(EMBEDDING_TYPE, embeddingType.toString()); } builder.endObject(); return builder; } + // Default for testing + String getTextEmbeddingsPath() { + return textEmbeddingsPath; + } + + @Override + public CustomServiceEmbeddingType getEmbeddingType() { + return embeddingType; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; TextEmbeddingResponseParser that = (TextEmbeddingResponseParser) o; - return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath); + return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath) && Objects.equals(embeddingType, that.embeddingType); } @Override public int hashCode() { - return Objects.hash(textEmbeddingsPath); + return Objects.hash(textEmbeddingsPath, embeddingType); } @Override @@ -89,17 +133,16 @@ public String getWriteableName() { } @Override - protected TextEmbeddingFloatResults transform(Map map) { + protected InferenceServiceResults transform(Map map) { var extractedResult = MapPathExtractor.extract(map, textEmbeddingsPath); var mapResultsList = validateList(extractedResult.extractedObject(), extractedResult.getArrayFieldName(0)); - var embeddings = new ArrayList(mapResultsList.size()); + var embeddingConverter = createEmbeddingConverter(embeddingType); for (int i = 0; i < mapResultsList.size(); i++) { try { var entry = mapResultsList.get(i); - var embeddingsAsListFloats = convertToListOfFloats(entry, extractedResult.getArrayFieldName(1)); - embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); + embeddingConverter.toEmbedding(entry, extractedResult.getArrayFieldName(1)); } catch (Exception e) { throw new IllegalArgumentException( Strings.format("Failed to parse text embedding entry [%d], error: %s", i, e.getMessage()), @@ -108,6 +151,74 @@ protected TextEmbeddingFloatResults transform(Map map) { } } - return new TextEmbeddingFloatResults(embeddings); + return embeddingConverter.getResults(); + } + + private static EmbeddingConverter createEmbeddingConverter(CustomServiceEmbeddingType embeddingType) { + return switch (embeddingType) { + case FLOAT -> new FloatEmbeddings(); + case BYTE -> new ByteEmbeddings(); + case BINARY, BIT -> new BitEmbeddings(); + }; + } + + private interface EmbeddingConverter { + void toEmbedding(Object entry, String fieldName); + + InferenceServiceResults getResults(); + } + + private static class FloatEmbeddings implements EmbeddingConverter { + + private final List embeddings; + + FloatEmbeddings() { + this.embeddings = new ArrayList<>(); + } + + public void toEmbedding(Object entry, String fieldName) { + var embeddingsAsListFloats = convertToListOfFloats(entry, fieldName); + embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); + } + + public TextEmbeddingFloatResults getResults() { + return new TextEmbeddingFloatResults(embeddings); + } + } + + private static class ByteEmbeddings implements EmbeddingConverter { + + private final List embeddings; + + ByteEmbeddings() { + this.embeddings = new ArrayList<>(); + } + + public void toEmbedding(Object entry, String fieldName) { + var convertedEmbeddings = convertToListOfBytes(entry, fieldName); + this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + } + + public TextEmbeddingByteResults getResults() { + return new TextEmbeddingByteResults(embeddings); + } + } + + private static class BitEmbeddings implements EmbeddingConverter { + + private final List embeddings; + + BitEmbeddings() { + this.embeddings = new ArrayList<>(); + } + + public void toEmbedding(Object entry, String fieldName) { + var convertedEmbeddings = convertToListOfBits(entry, fieldName); + this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + } + + public TextEmbeddingBitResults getResults() { + return new TextEmbeddingBitResults(embeddings); + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java index d98cb3d90a0e1..1c5c13e2086c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java @@ -10,7 +10,6 @@ import org.apache.http.HttpHeaders; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -93,7 +92,10 @@ public static CustomModel createModel( } public static CustomModel getTestModel() { - return getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")); + return getTestModel( + TaskType.TEXT_EMBEDDING, + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) + ); } public static CustomModel getTestModel(TaskType taskType, CustomResponseParser responseParser) { @@ -108,12 +110,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r String requestContentString = "\"input\":\"${input}\""; CustomServiceSettings serviceSettings = new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), url, headers, QueryParameters.EMPTY, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index d25db835fae04..65d9de30576ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -60,7 +60,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { var requestContentString = randomAlphaOfLength(10); var responseJsonParser = switch (taskType) { - case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); case SPARSE_EMBEDDING -> new SparseEmbeddingResponseParser( "$.result.sparse_embeddings[*].embedding[*].token_id", "$.result.sparse_embeddings[*].embedding[*].weights" @@ -77,12 +77,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); return new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - similarityMeasure, - dims, - maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, dims, maxInputTokens), url, headers, queryParameters, @@ -105,7 +100,7 @@ public void testFromMap() { var queryParameters = List.of(List.of("key", "value")); String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -138,20 +133,14 @@ public void testFromMap() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); assertThat( settings, is( new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), url, headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))), @@ -165,11 +154,246 @@ public void testFromMap() { ); } + public void testFromMap_EmbeddingType_Bit() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + CustomServiceEmbeddingType.BIT.toString() + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + } + + public void testFromMap_EmbeddingType_Binary() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + CustomServiceEmbeddingType.BINARY.toString() + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + } + + public void testFromMap_EmbeddingType_Byte() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + CustomServiceEmbeddingType.BYTE.toString() + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + + assertThat(settings.elementType(), is(DenseVectorFieldMapper.ElementType.BYTE)); + } + + public void testFromMap_Completion_NoEmbeddingType() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new CompletionResponseParser("$.result.text"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>(Map.of(CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.COMPLETION + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + assertNull(settings.elementType()); + } + + public void testFromMap_Completion_ThrowsWhenEmbeddingIsIncludedInMap() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + CompletionResponseParser.COMPLETION_PARSER_RESULT, + "$.result.text", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + "byte" + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.COMPLETION + ) + ); + + assertThat( + exception.getMessage(), + is( + "Configuration contains unknown settings [{embedding_type=byte}] while parsing field [json_parser] " + + "for settings [custom_service_settings]" + ) + ); + } + public void testFromMap_WithOptionalsNotSpecified() { String url = "http://www.abc.com"; String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -190,8 +414,7 @@ public void testFromMap_WithOptionalsNotSpecified() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( @@ -222,7 +445,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -251,20 +474,14 @@ public void testFromMap_RemovesNullValues_FromMaps() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( settings, is( new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), url, Map.of("value", "abc"), null, @@ -311,7 +528,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -358,7 +575,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -396,7 +613,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat(exception.getMessage(), is("Validation Failed: 1: [service_settings] does not contain the required setting [request];")); @@ -428,7 +645,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -471,7 +688,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { var exception = expectThrows( ElasticsearchStatusException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -511,7 +728,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { var exception = expectThrows( ElasticsearchStatusException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -549,7 +766,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { var exception = expectThrows( IllegalArgumentException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.CHAT_COMPLETION, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.CHAT_COMPLETION) ); assertThat(exception.getMessage(), is("Invalid task type received [chat_completion] while constructing response parser")); @@ -562,7 +779,53 @@ public void testXContent() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), + null + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": "string", + "response": { + "json_parser": { + "text_embeddings": "$.result.embeddings[*].embedding", + "embedding_type": "float" + } + }, + "input_type": { + "translation": {}, + "default": "" + }, + "rate_limit": { + "requests_per_minute": 10000 + }, + "batch_size": 10 + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_Rerank() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new RerankResponseParser( + "$.result.reranked_results[*].relevance_score", + "$.result.reranked_results[*].index", + "$.result.reranked_results[*].document_text" + ), null ); @@ -579,7 +842,9 @@ public void testXContent() throws IOException { "request": "string", "response": { "json_parser": { - "text_embeddings": "$.result.embeddings[*].embedding" + "relevance_score": "$.result.reranked_results[*].relevance_score", + "reranked_index": "$.result.reranked_results[*].index", + "document_text": "$.result.reranked_results[*].document_text" } }, "input_type": { @@ -603,7 +868,7 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null, null, new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default") @@ -622,7 +887,8 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { "request": "string", "response": { "json_parser": { - "text_embeddings": "$.result.embeddings[*].embedding" + "text_embeddings": "$.result.embeddings[*].embedding", + "embedding_type": "float" } }, "input_type": { @@ -649,7 +915,7 @@ public void testXContent_BatchSize11() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null, 11, InputTypeTranslator.EMPTY_TRANSLATOR @@ -668,7 +934,8 @@ public void testXContent_BatchSize11() throws IOException { "request": "string", "response": { "json_parser": { - "text_embeddings": "$.result.embeddings[*].embedding" + "text_embeddings": "$.result.embeddings[*].embedding", + "embedding_type": "float" } }, "input_type": { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 6ddb4ff71eeb3..cc1bb4471c0a9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -223,7 +222,7 @@ private static Map createSecretSettingsMap() { private static CustomModel createInternalEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel( similarityMeasure, - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), "http://www.abc.com" ); } @@ -244,7 +243,7 @@ private static CustomModel createInternalEmbeddingModel( TaskType.TEXT_EMBEDDING, CustomService.NAME, new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, DenseVectorFieldMapper.ElementType.FLOAT), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456), url, Map.of("key", "value"), QueryParameters.EMPTY, @@ -271,7 +270,7 @@ private static CustomModel createInternalEmbeddingModel( TaskType.TEXT_EMBEDDING, CustomService.NAME, new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, DenseVectorFieldMapper.ElementType.FLOAT), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456), url, Map.of("key", "value"), QueryParameters.EMPTY, @@ -318,7 +317,10 @@ public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOEx webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJson)); - var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + var model = createInternalEmbeddingModel( + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + getUrl(webServer) + ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -373,7 +375,10 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + var model = createInternalEmbeddingModel( + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + getUrl(webServer) + ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -653,7 +658,7 @@ public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNot public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var model = createInternalEmbeddingModel( SimilarityMeasure.DOT_PRODUCT, - new TextEmbeddingResponseParser("$.data[*].embedding"), + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), getUrl(webServer), ChunkingSettingsTests.createRandomChunkingSettings(), 2 @@ -738,7 +743,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { } public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { - var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + var model = createInternalEmbeddingModel( + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + getUrl(webServer) + ); String responseJson = """ { "object": "list", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index d1f606daef529..bb83897f27551 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -23,6 +22,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; @@ -56,17 +56,12 @@ public void testCreateRequest() throws IOException { """; var serviceSettings = new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), "${url}", headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings"), + new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000), null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") @@ -129,7 +124,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx ) ), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings"), + new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000), null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") @@ -182,17 +177,12 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws """; var serviceSettings = new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), "${url}", headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings"), + new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index 608fdb4d314c3..67bfe25fbdb6a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters; @@ -61,7 +62,7 @@ public void testFromTextEmbeddingResponse() throws IOException { var model = CustomModelTests.getTestModel( TaskType.TEXT_EMBEDDING, - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding") + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) ); var request = new CustomRequest( EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java index 82ddfa618d3b7..6bb6da009e27c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java @@ -16,9 +16,12 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,6 +29,8 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE; +import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -33,18 +38,25 @@ public class TextEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase { public static TextEmbeddingResponseParser createRandom() { - return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5)); + return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values())); } public void testFromMap() { var validation = new ValidationException(); var parser = TextEmbeddingResponseParser.fromMap( - new HashMap<>(Map.of(TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result[*].embeddings")), + new HashMap<>( + Map.of( + TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result[*].embeddings", + EMBEDDING_TYPE, + CustomServiceEmbeddingType.BIT.toString() + ) + ), "scope", validation ); - assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings"))); + assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT))); } public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { @@ -61,7 +73,7 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { } public void testToXContent() throws IOException { - var entity = new TextEmbeddingResponseParser("$.result.path"); + var entity = new TextEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); { @@ -74,7 +86,8 @@ public void testToXContent() throws IOException { var expected = XContentHelper.stripWhitespace(""" { "json_parser": { - "text_embeddings": "$.result.path" + "text_embeddings": "$.result.path", + "embedding_type": "binary" } } """); @@ -104,7 +117,7 @@ public void testParse() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -115,6 +128,66 @@ public void testParse() throws IOException { ); } + public void testParseByte() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1, + -2 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE); + TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, is(new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + } + + public void testParseBit() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1, + -2 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT); + TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, is(new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + } + public void testParse_MultipleEmbeddings() throws IOException { String responseJson = """ { @@ -145,7 +218,7 @@ public void testParse_MultipleEmbeddings() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -193,7 +266,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAListOfFloats() { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); var exception = expectThrows( IllegalArgumentException.class, () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) @@ -227,7 +300,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); var exception = expectThrows( IllegalArgumentException.class, () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) @@ -244,6 +317,9 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { @Override protected TextEmbeddingResponseParser mutateInstanceForVersion(TextEmbeddingResponseParser instance, TransportVersion version) { + if (version.before(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE)) { + return new TextEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT); + } return instance; }