From ad091b67a25a126d0826c0ccbeaff548f274ecbe Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 26 Jun 2025 16:55:57 -0400 Subject: [PATCH 1/2] Adding embedding type --- .../org/elasticsearch/TransportVersions.java | 2 + .../inference/services/ServiceUtils.java | 6 + .../services/custom/CustomService.java | 7 +- .../custom/CustomServiceEmbeddingType.java | 50 +++++ .../custom/CustomServiceSettings.java | 63 ++++-- .../response/BaseCustomResponseParser.java | 21 +- .../response/CompletionResponseParser.java | 2 +- .../custom/response/CustomResponseParser.java | 8 + .../custom/response/RerankResponseParser.java | 2 +- .../SparseEmbeddingResponseParser.java | 2 +- .../response/TextEmbeddingResponseParser.java | 136 +++++++++++-- .../services/custom/CustomModelTests.java | 8 +- .../custom/CustomServiceSettingsTests.java | 185 ++++++++++++++++-- .../services/custom/CustomServiceTests.java | 24 ++- .../custom/request/CustomRequestTests.java | 12 +- .../response/CustomResponseEntityTests.java | 3 +- .../TextEmbeddingResponseParserTests.java | 96 ++++++++- 17 files changed, 542 insertions(+), 85 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceEmbeddingType.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index b949cb62f6a06..cae1edc62a2c3 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -201,6 +201,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53); public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19 = def(8_841_0_56); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); @@ -310,6 +311,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00); public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = def(9_106_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index a3e9d69654f1f..adbec49328804 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -1085,5 +1085,11 @@ public static void validateInputTypeAgainstAllowlist( } } + public static void checkByteBounds(short value) { + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte"); + } + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index deb6e17ec5311..d04a2bcb7960a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -333,12 +333,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; return new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - similarityToUse, - embeddingSize, - serviceSettings.getMaxInputTokens(), - serviceSettings.elementType() - ), + new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens(), null), serviceSettings.getUrl(), serviceSettings.getHeaders(), serviceSettings.getQueryParameters(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceEmbeddingType.java new file mode 100644 index 0000000000000..143a0b2dcb6fe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceEmbeddingType.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +import java.util.Locale; + +public enum CustomServiceEmbeddingType { + /** + * Use this when you want to get back the default float embeddings. + */ + FLOAT(DenseVectorFieldMapper.ElementType.FLOAT), + /** + * Use this when you want to get back signed int8 embeddings. + */ + BYTE(DenseVectorFieldMapper.ElementType.BYTE), + /** + * Use this when you want to get back binary embeddings. + */ + BIT(DenseVectorFieldMapper.ElementType.BIT), + /** + * This is a synonym for BIT + */ + BINARY(DenseVectorFieldMapper.ElementType.BIT); + + private final DenseVectorFieldMapper.ElementType elementType; + + CustomServiceEmbeddingType(DenseVectorFieldMapper.ElementType elementType) { + this.elementType = elementType; + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + public DenseVectorFieldMapper.ElementType toElementType() { + return elementType; + } + + public static CustomServiceEmbeddingType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 83048120bc545..aabd746fe5aee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -137,19 +137,14 @@ public static CustomServiceSettings fromMap( ); } - public record TextEmbeddingSettings( - @Nullable SimilarityMeasure similarityMeasure, - @Nullable Integer dimensions, - @Nullable Integer maxInputTokens, - @Nullable DenseVectorFieldMapper.ElementType elementType - ) implements ToXContentFragment, Writeable { + public static class TextEmbeddingSettings implements ToXContentFragment, Writeable { // This specifies float for the element type but null for all other settings public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings( null, null, null, - DenseVectorFieldMapper.ElementType.FLOAT + CustomServiceEmbeddingType.FLOAT ); // This refers to settings that are not related to the text embedding task type (all the settings should be null) public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); @@ -162,16 +157,33 @@ public static TextEmbeddingSettings fromMap(Map map, TaskType ta SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - return new TextEmbeddingSettings(similarity, dims, maxInputTokens, DenseVectorFieldMapper.ElementType.FLOAT); + return new TextEmbeddingSettings(similarity, dims, maxInputTokens, null); + } + + private final SimilarityMeasure similarityMeasure; + private final Integer dimensions; + private final Integer maxInputTokens; + + public TextEmbeddingSettings( + @Nullable SimilarityMeasure similarityMeasure, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable CustomServiceEmbeddingType embeddingType + ) { + this.similarityMeasure = similarityMeasure; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; } public TextEmbeddingSettings(StreamInput in) throws IOException { - this( - in.readOptionalEnum(SimilarityMeasure.class), - in.readOptionalVInt(), - in.readOptionalVInt(), - in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class) - ); + this.similarityMeasure = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + + if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) + && in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19) == false) { + in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class); + } } @Override @@ -179,7 +191,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(similarityMeasure); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); - out.writeOptionalEnum(elementType); + + if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) + && out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19) == false) { + out.writeOptionalEnum(null); + } } @Override @@ -193,8 +209,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (maxInputTokens != null) { builder.field(MAX_INPUT_TOKENS, maxInputTokens); } + return builder; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + TextEmbeddingSettings that = (TextEmbeddingSettings) o; + return similarityMeasure == that.similarityMeasure + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens); + } + + @Override + public int hashCode() { + return Objects.hash(similarityMeasure, dimensions, maxInputTokens); + } } private final TextEmbeddingSettings textEmbeddingSettings; @@ -300,7 +331,7 @@ public Integer dimensions() { @Override public DenseVectorFieldMapper.ElementType elementType() { - return textEmbeddingSettings.elementType; + return responseJsonParser.getEmbeddingType().toElementType(); } public Integer getMaxInputTokens() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java index 99b035ef056c7..ac9a223de17c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParser.java @@ -22,7 +22,9 @@ import java.util.Objects; import java.util.function.BiFunction; -public abstract class BaseCustomResponseParser implements CustomResponseParser { +import static org.elasticsearch.xpack.inference.services.ServiceUtils.checkByteBounds; + +public abstract class BaseCustomResponseParser implements CustomResponseParser { @Override public InferenceServiceResults parse(HttpResult response) throws IOException { @@ -36,7 +38,7 @@ public InferenceServiceResults parse(HttpResult response) throws IOException { } } - protected abstract T transform(Map extractedField); + protected abstract InferenceServiceResults transform(Map extractedField); static List validateList(Object obj, String fieldName) { validateNonNull(obj, fieldName); @@ -97,6 +99,21 @@ static Float toFloat(Object obj, String fieldName) { return toNumber(obj, fieldName).floatValue(); } + static List convertToListOfBits(Object obj, String fieldName) { + return convertToListOfBytes(obj, fieldName); + } + + static List convertToListOfBytes(Object obj, String fieldName) { + return castList(validateList(obj, fieldName), BaseCustomResponseParser::toByte, fieldName); + } + + static Byte toByte(Object obj, String fieldName) { + var shortValue = toNumber(obj, fieldName).shortValue(); + checkByteBounds(shortValue); + + return (byte) shortValue; + } + private static Number toNumber(Object obj, String fieldName) { if (obj instanceof Number == false) { throw new IllegalArgumentException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java index ecd3125e228c9..e0cf5643f3d05 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java @@ -23,7 +23,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class CompletionResponseParser extends BaseCustomResponseParser { +public class CompletionResponseParser extends BaseCustomResponseParser { public static final String NAME = "completion_response_parser"; public static final String COMPLETION_PARSER_RESULT = "completion_result"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java index 3a421307d76a8..01fb7c79d7353 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java @@ -11,9 +11,17 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import java.io.IOException; public interface CustomResponseParser extends ToXContentFragment, NamedWriteable { InferenceServiceResults parse(HttpResult response) throws IOException; + + /** + * Returns the configured embedding type for this response parser. This should be overridden for text embedding parsers. + */ + default CustomServiceEmbeddingType getEmbeddingType() { + return null; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java index 0a4c2c42b8c79..8c182a5f4b8a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java @@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class RerankResponseParser extends BaseCustomResponseParser { +public class RerankResponseParser extends BaseCustomResponseParser { public static final String NAME = "rerank_response_parser"; public static final String RERANK_PARSER_SCORE = "relevance_score"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java index 5e21bb018c1fe..85046eb86c9a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java @@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class SparseEmbeddingResponseParser extends BaseCustomResponseParser { +public class SparseEmbeddingResponseParser extends BaseCustomResponseParser { public static final String NAME = "sparse_embedding_response_parser"; public static final String SPARSE_EMBEDDING_TOKEN_PATH = "token_path"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java index b5b0a191f3c4e..54ed46336fb60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java @@ -7,34 +7,42 @@ package org.elasticsearch.xpack.inference.services.custom.response; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.common.MapPathExtractor; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import java.io.IOException; import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; -public class TextEmbeddingResponseParser extends BaseCustomResponseParser { +public class TextEmbeddingResponseParser extends BaseCustomResponseParser { public static final String NAME = "text_embedding_response_parser"; public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings"; - - private final String textEmbeddingsPath; + public static final String EMBEDDING_TYPE = "embedding_type"; public static TextEmbeddingResponseParser fromMap( Map responseParserMap, String scope, ValidationException validationException ) { + var jsonParserScope = String.join(".", scope, JSON_PARSER); var path = extractRequiredString( responseParserMap, TEXT_EMBEDDING_PARSER_EMBEDDINGS, @@ -42,45 +50,82 @@ public static TextEmbeddingResponseParser fromMap( validationException ); + var embeddingType = Objects.requireNonNullElse( + extractOptionalEnum( + responseParserMap, + EMBEDDING_TYPE, + jsonParserScope, + CustomServiceEmbeddingType::fromString, + EnumSet.allOf(CustomServiceEmbeddingType.class), + validationException + ), + CustomServiceEmbeddingType.FLOAT + ); + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new TextEmbeddingResponseParser(path); + return new TextEmbeddingResponseParser(path, embeddingType); } - public TextEmbeddingResponseParser(String textEmbeddingsPath) { + private final String textEmbeddingsPath; + private final CustomServiceEmbeddingType embeddingType; + + public TextEmbeddingResponseParser(String textEmbeddingsPath, CustomServiceEmbeddingType embeddingType) { this.textEmbeddingsPath = Objects.requireNonNull(textEmbeddingsPath); + this.embeddingType = Objects.requireNonNull(embeddingType); } public TextEmbeddingResponseParser(StreamInput in) throws IOException { this.textEmbeddingsPath = in.readString(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19)) { + this.embeddingType = in.readEnum(CustomServiceEmbeddingType.class); + } else { + this.embeddingType = CustomServiceEmbeddingType.FLOAT; + } } public void writeTo(StreamOutput out) throws IOException { out.writeString(textEmbeddingsPath); + + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19)) { + out.writeEnum(embeddingType); + } } public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(JSON_PARSER); { builder.field(TEXT_EMBEDDING_PARSER_EMBEDDINGS, textEmbeddingsPath); + builder.field(EMBEDDING_TYPE, embeddingType.toString()); } builder.endObject(); return builder; } + // For testing + String getTextEmbeddingsPath() { + return textEmbeddingsPath; + } + + public CustomServiceEmbeddingType getEmbeddingType() { + return embeddingType; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; TextEmbeddingResponseParser that = (TextEmbeddingResponseParser) o; - return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath); + return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath) && Objects.equals(embeddingType, that.embeddingType); } @Override public int hashCode() { - return Objects.hash(textEmbeddingsPath); + return Objects.hash(textEmbeddingsPath, embeddingType); } @Override @@ -89,17 +134,16 @@ public String getWriteableName() { } @Override - protected TextEmbeddingFloatResults transform(Map map) { + protected InferenceServiceResults transform(Map map) { var extractedResult = MapPathExtractor.extract(map, textEmbeddingsPath); var mapResultsList = validateList(extractedResult.extractedObject(), extractedResult.getArrayFieldName(0)); - var embeddings = new ArrayList(mapResultsList.size()); + var embeddingConverter = createEmbeddingConverter(embeddingType); for (int i = 0; i < mapResultsList.size(); i++) { try { var entry = mapResultsList.get(i); - var embeddingsAsListFloats = convertToListOfFloats(entry, extractedResult.getArrayFieldName(1)); - embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); + embeddingConverter.toEmbedding(entry, extractedResult.getArrayFieldName(1)); } catch (Exception e) { throw new IllegalArgumentException( Strings.format("Failed to parse text embedding entry [%d], error: %s", i, e.getMessage()), @@ -108,6 +152,74 @@ protected TextEmbeddingFloatResults transform(Map map) { } } - return new TextEmbeddingFloatResults(embeddings); + return embeddingConverter.getResults(); + } + + private static EmbeddingConverter createEmbeddingConverter(CustomServiceEmbeddingType embeddingType) { + return switch (embeddingType) { + case FLOAT -> new FloatEmbeddings(); + case BYTE -> new ByteEmbeddings(); + case BINARY, BIT -> new BitEmbeddings(); + }; + } + + private interface EmbeddingConverter { + void toEmbedding(Object entry, String fieldName); + + InferenceServiceResults getResults(); + } + + private static class FloatEmbeddings implements EmbeddingConverter { + + private final List embeddings; + + FloatEmbeddings() { + this.embeddings = new ArrayList<>(); + } + + public void toEmbedding(Object entry, String fieldName) { + var embeddingsAsListFloats = convertToListOfFloats(entry, fieldName); + embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); + } + + public TextEmbeddingFloatResults getResults() { + return new TextEmbeddingFloatResults(embeddings); + } + } + + private static class ByteEmbeddings implements EmbeddingConverter { + + private final List embeddings; + + ByteEmbeddings() { + this.embeddings = new ArrayList<>(); + } + + public void toEmbedding(Object entry, String fieldName) { + var convertedEmbeddings = convertToListOfBytes(entry, fieldName); + this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + } + + public TextEmbeddingByteResults getResults() { + return new TextEmbeddingByteResults(embeddings); + } + } + + private static class BitEmbeddings implements EmbeddingConverter { + + private final List embeddings; + + BitEmbeddings() { + this.embeddings = new ArrayList<>(); + } + + public void toEmbedding(Object entry, String fieldName) { + var convertedEmbeddings = convertToListOfBits(entry, fieldName); + this.embeddings.add(TextEmbeddingByteResults.Embedding.of(convertedEmbeddings)); + } + + public TextEmbeddingBitResults getResults() { + return new TextEmbeddingBitResults(embeddings); + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java index d98cb3d90a0e1..413ce43157679 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java @@ -10,7 +10,6 @@ import org.apache.http.HttpHeaders; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -93,7 +92,10 @@ public static CustomModel createModel( } public static CustomModel getTestModel() { - return getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")); + return getTestModel( + TaskType.TEXT_EMBEDDING, + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) + ); } public static CustomModel getTestModel(TaskType taskType, CustomResponseParser responseParser) { @@ -112,7 +114,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT + CustomServiceEmbeddingType.FLOAT ), url, headers, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index d25db835fae04..c460af6810f4b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -60,7 +59,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { var requestContentString = randomAlphaOfLength(10); var responseJsonParser = switch (taskType) { - case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + case TEXT_EMBEDDING -> new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); case SPARSE_EMBEDDING -> new SparseEmbeddingResponseParser( "$.result.sparse_embeddings[*].embedding[*].token_id", "$.result.sparse_embeddings[*].embedding[*].weights" @@ -77,12 +76,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); return new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - similarityMeasure, - dims, - maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, dims, maxInputTokens, CustomServiceEmbeddingType.FLOAT), url, headers, queryParameters, @@ -105,7 +99,7 @@ public void testFromMap() { var queryParameters = List.of(List.of("key", "value")); String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -150,7 +144,7 @@ public void testFromMap() { SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT + CustomServiceEmbeddingType.FLOAT ), url, headers, @@ -165,11 +159,161 @@ public void testFromMap() { ); } + public void testFromMap_EmbeddingType_Bit() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BIT); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + CustomServiceEmbeddingType.BIT.toString() + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null, CustomServiceEmbeddingType.BIT), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + } + + public void testFromMap_EmbeddingType_Binary() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BINARY); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + CustomServiceEmbeddingType.BINARY.toString() + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null, CustomServiceEmbeddingType.BINARY), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + } + + public void testFromMap_EmbeddingType_Byte() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.BYTE); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + CustomServiceEmbeddingType.BYTE.toString() + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.TEXT_EMBEDDING, + "inference_id" + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null, CustomServiceEmbeddingType.BYTE), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + } + public void testFromMap_WithOptionalsNotSpecified() { String url = "http://www.abc.com"; String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -222,7 +366,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { String requestContentString = "request body"; - var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"); + var responseParser = new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT); var settings = CustomServiceSettings.fromMap( new HashMap<>( @@ -263,7 +407,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT + CustomServiceEmbeddingType.FLOAT ), url, Map.of("value", "abc"), @@ -562,7 +706,7 @@ public void testXContent() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null ); @@ -579,7 +723,8 @@ public void testXContent() throws IOException { "request": "string", "response": { "json_parser": { - "text_embeddings": "$.result.embeddings[*].embedding" + "text_embeddings": "$.result.embeddings[*].embedding", + "embedding_type": "float" } }, "input_type": { @@ -603,7 +748,7 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null, null, new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default") @@ -622,7 +767,8 @@ public void testXContent_WithInputTypeTranslationValues() throws IOException { "request": "string", "response": { "json_parser": { - "text_embeddings": "$.result.embeddings[*].embedding" + "text_embeddings": "$.result.embeddings[*].embedding", + "embedding_type": "float" } }, "input_type": { @@ -649,7 +795,7 @@ public void testXContent_BatchSize11() throws IOException { Map.of("key", "value"), null, "string", - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), null, 11, InputTypeTranslator.EMPTY_TRANSLATOR @@ -668,7 +814,8 @@ public void testXContent_BatchSize11() throws IOException { "request": "string", "response": { "json_parser": { - "text_embeddings": "$.result.embeddings[*].embedding" + "text_embeddings": "$.result.embeddings[*].embedding", + "embedding_type": "float" } }, "input_type": { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 6ddb4ff71eeb3..2a64c0bab6b32 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -223,7 +222,7 @@ private static Map createSecretSettingsMap() { private static CustomModel createInternalEmbeddingModel(SimilarityMeasure similarityMeasure) { return createInternalEmbeddingModel( similarityMeasure, - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT), "http://www.abc.com" ); } @@ -244,7 +243,7 @@ private static CustomModel createInternalEmbeddingModel( TaskType.TEXT_EMBEDDING, CustomService.NAME, new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, DenseVectorFieldMapper.ElementType.FLOAT), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, CustomServiceEmbeddingType.FLOAT), url, Map.of("key", "value"), QueryParameters.EMPTY, @@ -271,7 +270,7 @@ private static CustomModel createInternalEmbeddingModel( TaskType.TEXT_EMBEDDING, CustomService.NAME, new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, DenseVectorFieldMapper.ElementType.FLOAT), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, CustomServiceEmbeddingType.FLOAT), url, Map.of("key", "value"), QueryParameters.EMPTY, @@ -318,7 +317,10 @@ public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOEx webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJson)); - var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + var model = createInternalEmbeddingModel( + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + getUrl(webServer) + ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -373,7 +375,10 @@ public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOExcep webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + var model = createInternalEmbeddingModel( + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + getUrl(webServer) + ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -653,7 +658,7 @@ public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNot public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var model = createInternalEmbeddingModel( SimilarityMeasure.DOT_PRODUCT, - new TextEmbeddingResponseParser("$.data[*].embedding"), + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), getUrl(webServer), ChunkingSettingsTests.createRandomChunkingSettings(), 2 @@ -738,7 +743,10 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { } public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { - var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + var model = createInternalEmbeddingModel( + new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + getUrl(webServer) + ); String responseJson = """ { "object": "list", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index d1f606daef529..94c40fd71d5c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -23,6 +22,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator; @@ -60,13 +60,13 @@ public void testCreateRequest() throws IOException { SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT + CustomServiceEmbeddingType.FLOAT ), "${url}", headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings"), + new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000), null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") @@ -129,7 +129,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx ) ), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings"), + new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000), null, new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default") @@ -186,13 +186,13 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, - DenseVectorFieldMapper.ElementType.FLOAT + CustomServiceEmbeddingType.FLOAT ), "${url}", headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, - new TextEmbeddingResponseParser("$.result.embeddings"), + new TextEmbeddingResponseParser("$.result.embeddings", CustomServiceEmbeddingType.FLOAT), new RateLimitSettings(10_000) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index 608fdb4d314c3..67bfe25fbdb6a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.services.custom.CustomModelTests; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters; import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters; @@ -61,7 +62,7 @@ public void testFromTextEmbeddingResponse() throws IOException { var model = CustomModelTests.getTestModel( TaskType.TEXT_EMBEDDING, - new TextEmbeddingResponseParser("$.result.embeddings[*].embedding") + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) ); var request = new CustomRequest( EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java index 82ddfa618d3b7..c898af32157f1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java @@ -16,9 +16,12 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceEmbeddingType; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,6 +29,9 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE; +import static org.elasticsearch.TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19; +import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.EMBEDDING_TYPE; import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -33,18 +39,25 @@ public class TextEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase { public static TextEmbeddingResponseParser createRandom() { - return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5)); + return new TextEmbeddingResponseParser("$." + randomAlphaOfLength(5), randomFrom(CustomServiceEmbeddingType.values())); } public void testFromMap() { var validation = new ValidationException(); var parser = TextEmbeddingResponseParser.fromMap( - new HashMap<>(Map.of(TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result[*].embeddings")), + new HashMap<>( + Map.of( + TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result[*].embeddings", + EMBEDDING_TYPE, + CustomServiceEmbeddingType.BIT.toString() + ) + ), "scope", validation ); - assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings"))); + assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings", CustomServiceEmbeddingType.BIT))); } public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { @@ -61,7 +74,7 @@ public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { } public void testToXContent() throws IOException { - var entity = new TextEmbeddingResponseParser("$.result.path"); + var entity = new TextEmbeddingResponseParser("$.result.path", CustomServiceEmbeddingType.BINARY); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); { @@ -74,7 +87,8 @@ public void testToXContent() throws IOException { var expected = XContentHelper.stripWhitespace(""" { "json_parser": { - "text_embeddings": "$.result.path" + "text_embeddings": "$.result.path", + "embedding_type": "binary" } } """); @@ -104,7 +118,7 @@ public void testParse() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -115,6 +129,66 @@ public void testParse() throws IOException { ); } + public void testParseByte() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1, + -2 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BYTE); + TextEmbeddingByteResults parsedResults = (TextEmbeddingByteResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, is(new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + } + + public void testParseBit() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1, + -2 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.BIT); + TextEmbeddingBitResults parsedResults = (TextEmbeddingBitResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, is(new TextEmbeddingBitResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { 1, -2 }))))); + } + public void testParse_MultipleEmbeddings() throws IOException { String responseJson = """ { @@ -145,7 +219,7 @@ public void testParse_MultipleEmbeddings() throws IOException { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -193,7 +267,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAListOfFloats() { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); var exception = expectThrows( IllegalArgumentException.class, () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) @@ -227,7 +301,7 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { } """; - var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var parser = new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT); var exception = expectThrows( IllegalArgumentException.class, () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) @@ -244,6 +318,10 @@ public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { @Override protected TextEmbeddingResponseParser mutateInstanceForVersion(TextEmbeddingResponseParser instance, TransportVersion version) { + if (version.before(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE) + && version.isPatchFrom(ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE_8_19) == false) { + return new TextEmbeddingResponseParser(instance.getTextEmbeddingsPath(), CustomServiceEmbeddingType.FLOAT); + } return instance; } From 2a9330b8727b71cdf8e499e95b5eb9fcf29c4a2b Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 27 Jun 2025 11:50:46 -0400 Subject: [PATCH 2/2] Adding more tests and cleaning up --- .../services/custom/CustomModel.java | 4 +- .../services/custom/CustomService.java | 2 +- .../custom/CustomServiceSettings.java | 28 +-- .../response/TextEmbeddingResponseParser.java | 3 +- .../services/custom/CustomModelTests.java | 7 +- .../custom/CustomServiceSettingsTests.java | 190 ++++++++++++++---- .../services/custom/CustomServiceTests.java | 4 +- .../custom/request/CustomRequestTests.java | 14 +- 8 files changed, 176 insertions(+), 76 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java index b23f515055b9d..384acefc50e3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java @@ -46,7 +46,7 @@ public CustomModel( inferenceId, taskType, service, - CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomServiceSettings.fromMap(serviceSettings, context, taskType), CustomTaskSettings.fromMap(taskSettings), CustomSecretSettings.fromMap(secrets) ); @@ -66,7 +66,7 @@ public CustomModel( inferenceId, taskType, service, - CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomServiceSettings.fromMap(serviceSettings, context, taskType), CustomTaskSettings.fromMap(taskSettings), CustomSecretSettings.fromMap(secrets), chunkingSettings diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index d04a2bcb7960a..3af0a8d55dfb3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -333,7 +333,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; return new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens(), null), + new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens()), serviceSettings.getUrl(), serviceSettings.getHeaders(), serviceSettings.getQueryParameters(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 748271401a66e..89d209b1099b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -66,12 +66,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10; - public static CustomServiceSettings fromMap( - Map map, - ConfigurationParseContext context, - TaskType taskType, - String inferenceId - ) { + public static CustomServiceSettings fromMap(Map map, ConfigurationParseContext context, TaskType taskType) { ValidationException validationException = new ValidationException(); var textEmbeddingSettings = TextEmbeddingSettings.fromMap(map, taskType, validationException); @@ -140,14 +135,9 @@ public static CustomServiceSettings fromMap( public static class TextEmbeddingSettings implements ToXContentFragment, Writeable { // This specifies float for the element type but null for all other settings - public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings( - null, - null, - null, - CustomServiceEmbeddingType.FLOAT - ); + public static final TextEmbeddingSettings DEFAULT_FLOAT = new TextEmbeddingSettings(null, null, null); // This refers to settings that are not related to the text embedding task type (all the settings should be null) - public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); + public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null); public static TextEmbeddingSettings fromMap(Map map, TaskType taskType, ValidationException validationException) { if (taskType != TaskType.TEXT_EMBEDDING) { @@ -157,7 +147,7 @@ public static TextEmbeddingSettings fromMap(Map map, TaskType ta SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - return new TextEmbeddingSettings(similarity, dims, maxInputTokens, null); + return new TextEmbeddingSettings(similarity, dims, maxInputTokens); } private final SimilarityMeasure similarityMeasure; @@ -167,8 +157,7 @@ public static TextEmbeddingSettings fromMap(Map map, TaskType ta public TextEmbeddingSettings( @Nullable SimilarityMeasure similarityMeasure, @Nullable Integer dimensions, - @Nullable Integer maxInputTokens, - @Nullable CustomServiceEmbeddingType embeddingType + @Nullable Integer maxInputTokens ) { this.similarityMeasure = similarityMeasure; this.dimensions = dimensions; @@ -331,7 +320,12 @@ public Integer dimensions() { @Override public DenseVectorFieldMapper.ElementType elementType() { - return responseJsonParser.getEmbeddingType().toElementType(); + var embeddingType = responseJsonParser.getEmbeddingType(); + if (embeddingType != null) { + return embeddingType.toElementType(); + } + + return null; } public Integer getMaxInputTokens() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java index 54ed46336fb60..4a8fef2731764 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java @@ -106,11 +106,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - // For testing + // Default for testing String getTextEmbeddingsPath() { return textEmbeddingsPath; } + @Override public CustomServiceEmbeddingType getEmbeddingType() { return embeddingType; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java index 413ce43157679..1c5c13e2086c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java @@ -110,12 +110,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r String requestContentString = "\"input\":\"${input}\""; CustomServiceSettings serviceSettings = new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - CustomServiceEmbeddingType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), url, headers, QueryParameters.EMPTY, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index c460af6810f4b..65d9de30576ff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -76,7 +77,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); return new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, dims, maxInputTokens, CustomServiceEmbeddingType.FLOAT), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, dims, maxInputTokens), url, headers, queryParameters, @@ -132,20 +133,14 @@ public void testFromMap() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); assertThat( settings, is( new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - CustomServiceEmbeddingType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), url, headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))), @@ -189,15 +184,14 @@ public void testFromMap_EmbeddingType_Bit() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( settings, is( new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(null, null, null, CustomServiceEmbeddingType.BIT), + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), url, Map.of(), null, @@ -239,15 +233,14 @@ public void testFromMap_EmbeddingType_Binary() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( settings, is( new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(null, null, null, CustomServiceEmbeddingType.BINARY), + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), url, Map.of(), null, @@ -289,15 +282,14 @@ public void testFromMap_EmbeddingType_Byte() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( settings, is( new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(null, null, null, CustomServiceEmbeddingType.BYTE), + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), url, Map.of(), null, @@ -307,6 +299,94 @@ public void testFromMap_EmbeddingType_Byte() { ) ) ); + + assertThat(settings.elementType(), is(DenseVectorFieldMapper.ElementType.BYTE)); + } + + public void testFromMap_Completion_NoEmbeddingType() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var responseParser = new CompletionResponseParser("$.result.text"); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>(Map.of(CompletionResponseParser.COMPLETION_PARSER_RESULT, "$.result.text")) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.COMPLETION + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(null, null, null), + url, + Map.of(), + null, + requestContentString, + responseParser, + new RateLimitSettings(10_000) + ) + ) + ); + assertNull(settings.elementType()); + } + + public void testFromMap_Completion_ThrowsWhenEmbeddingIsIncludedInMap() { + String url = "http://www.abc.com"; + String requestContentString = "request body"; + + var exception = expectThrows( + ElasticsearchStatusException.class, + () -> CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + CustomServiceSettings.URL, + url, + CustomServiceSettings.REQUEST, + requestContentString, + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + CompletionResponseParser.COMPLETION_PARSER_RESULT, + "$.result.text", + TextEmbeddingResponseParser.EMBEDDING_TYPE, + "byte" + ) + ) + ) + ) + ) + ), + ConfigurationParseContext.REQUEST, + TaskType.COMPLETION + ) + ); + + assertThat( + exception.getMessage(), + is( + "Configuration contains unknown settings [{embedding_type=byte}] while parsing field [json_parser] " + + "for settings [custom_service_settings]" + ) + ); } public void testFromMap_WithOptionalsNotSpecified() { @@ -334,8 +414,7 @@ public void testFromMap_WithOptionalsNotSpecified() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( @@ -395,20 +474,14 @@ public void testFromMap_RemovesNullValues_FromMaps() { ) ), ConfigurationParseContext.REQUEST, - TaskType.TEXT_EMBEDDING, - "inference_id" + TaskType.TEXT_EMBEDDING ); MatcherAssert.assertThat( settings, is( new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - CustomServiceEmbeddingType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), url, Map.of("value", "abc"), null, @@ -455,7 +528,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -502,7 +575,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -540,7 +613,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat(exception.getMessage(), is("Validation Failed: 1: [service_settings] does not contain the required setting [request];")); @@ -572,7 +645,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() { var exception = expectThrows( ValidationException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -615,7 +688,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { var exception = expectThrows( ElasticsearchStatusException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -655,7 +728,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { var exception = expectThrows( ElasticsearchStatusException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING) ); assertThat( @@ -693,7 +766,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { var exception = expectThrows( IllegalArgumentException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.CHAT_COMPLETION, "inference_id") + () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.CHAT_COMPLETION) ); assertThat(exception.getMessage(), is("Invalid task type received [chat_completion] while constructing response parser")); @@ -741,6 +814,53 @@ public void testXContent() throws IOException { assertThat(xContentResult, is(expected)); } + public void testXContent_Rerank() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new RerankResponseParser( + "$.result.reranked_results[*].relevance_score", + "$.result.reranked_results[*].index", + "$.result.reranked_results[*].document_text" + ), + null + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": "string", + "response": { + "json_parser": { + "relevance_score": "$.result.reranked_results[*].relevance_score", + "reranked_index": "$.result.reranked_results[*].index", + "document_text": "$.result.reranked_results[*].document_text" + } + }, + "input_type": { + "translation": {}, + "default": "" + }, + "rate_limit": { + "requests_per_minute": 10000 + }, + "batch_size": 10 + } + """); + + assertThat(xContentResult, is(expected)); + } + public void testXContent_WithInputTypeTranslationValues() throws IOException { var entity = new CustomServiceSettings( CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 2a64c0bab6b32..cc1bb4471c0a9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -243,7 +243,7 @@ private static CustomModel createInternalEmbeddingModel( TaskType.TEXT_EMBEDDING, CustomService.NAME, new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, CustomServiceEmbeddingType.FLOAT), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456), url, Map.of("key", "value"), QueryParameters.EMPTY, @@ -270,7 +270,7 @@ private static CustomModel createInternalEmbeddingModel( TaskType.TEXT_EMBEDDING, CustomService.NAME, new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, CustomServiceEmbeddingType.FLOAT), + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456), url, Map.of("key", "value"), QueryParameters.EMPTY, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 94c40fd71d5c7..bb83897f27551 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -56,12 +56,7 @@ public void testCreateRequest() throws IOException { """; var serviceSettings = new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - CustomServiceEmbeddingType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), "${url}", headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), @@ -182,12 +177,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws """; var serviceSettings = new CustomServiceSettings( - new CustomServiceSettings.TextEmbeddingSettings( - SimilarityMeasure.DOT_PRODUCT, - dims, - maxInputTokens, - CustomServiceEmbeddingType.FLOAT - ), + new CustomServiceSettings.TextEmbeddingSettings(SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens), "${url}", headers, new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))),