diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 29e0e41e856b1..494283835ca76 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -198,6 +198,7 @@ static TransportVersion def(int id) { public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50); public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -304,6 +305,7 @@ static TransportVersion def(int id) { public static final TransportVersion STATE_PARAM_GET_SNAPSHOT = def(9_100_0_00); public static final TransportVersion PROJECT_ID_IN_SNAPSHOTS_DELETIONS_AND_REPO_CLEANUP = def(9_101_0_00); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java index 7c00b0a242f94..b23f515055b9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.custom; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -51,6 +52,27 @@ public CustomModel( ); } + public CustomModel( + String inferenceId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + @Nullable ChunkingSettings chunkingSettings, + ConfigurationParseContext context + ) { + this( + inferenceId, + taskType, + service, + CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomTaskSettings.fromMap(taskSettings), + CustomSecretSettings.fromMap(secrets), + chunkingSettings + ); + } + // should only be used for testing CustomModel( String inferenceId, @@ -67,6 +89,23 @@ public CustomModel( ); } + // should only be used for testing + CustomModel( + String inferenceId, + TaskType taskType, + String service, + CustomServiceSettings serviceSettings, + CustomTaskSettings taskSettings, + @Nullable CustomSecretSettings secretSettings, + @Nullable ChunkingSettings chunkingSettings + ) { + this( + new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + serviceSettings + ); + } + protected CustomModel(CustomModel model, TaskSettings taskSettings) { super(model, taskSettings); rateLimitServiceSettings = model.rateLimitServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 77b852f43cd8f..873e5515647cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -27,6 +28,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -45,6 +48,7 @@ import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -81,12 +85,15 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + var chunkingSettings = extractChunkingSettings(config, taskType); + CustomModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, serviceSettingsMap, + chunkingSettings, ConfigurationParseContext.REQUEST ); @@ -100,6 +107,14 @@ public void parseRequestConfig( } } + private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) { + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return null; + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); @@ -125,7 +140,8 @@ private static CustomModel createModelWithoutLoggingDeprecations( TaskType taskType, Map serviceSettings, Map taskSettings, - @Nullable Map secretSettings + @Nullable Map secretSettings, + @Nullable ChunkingSettings chunkingSettings ) { return createModel( inferenceEntityId, @@ -133,6 +149,7 @@ private static CustomModel createModelWithoutLoggingDeprecations( serviceSettings, taskSettings, secretSettings, + chunkingSettings, ConfigurationParseContext.PERSISTENT ); } @@ -143,12 +160,13 @@ private static CustomModel createModel( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, + @Nullable ChunkingSettings chunkingSettings, ConfigurationParseContext context ) { if (supportedTaskTypes.contains(taskType) == false) { throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); } - return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context); } @Override @@ -162,7 +180,16 @@ public CustomModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap); + var chunkingSettings = extractChunkingSettings(config, taskType); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + chunkingSettings + ); } @Override @@ -170,7 +197,16 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null); + var chunkingSettings = extractChunkingSettings(config, taskType); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + chunkingSettings + ); } @Override @@ -211,7 +247,27 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME)); + if (model instanceof CustomModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var customModel = (CustomModel) model; + var overriddenModel = CustomModel.of(customModel, taskSettings); + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME); + var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + customModel.getServiceSettings().getBatchSize(), + customModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 8b8b270db3bd9..15caace9820fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; @@ -52,8 +53,10 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings { + public static final String NAME = "custom_service_settings"; public static final String URL = "url"; + public static final String BATCH_SIZE = "batch_size"; public static final String HEADERS = "headers"; public static final String REQUEST = "request"; public static final String RESPONSE = "response"; @@ -61,6 +64,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); + private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10; public static CustomServiceSettings fromMap( Map map, @@ -106,6 +110,8 @@ public static CustomServiceSettings fromMap( context ); + var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException); + if (responseParserMap == null || jsonParserMap == null) { throw validationException; } @@ -124,7 +130,8 @@ public static CustomServiceSettings fromMap( queryParams, requestContentString, responseJsonParser, - rateLimitSettings + rateLimitSettings, + batchSize ); } @@ -142,7 +149,6 @@ public record TextEmbeddingSettings( null, DenseVectorFieldMapper.ElementType.FLOAT ); - // This refers to settings that are not related to the text embedding task type (all the settings should be null) public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); @@ -196,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private final String requestContentString; private final CustomResponseParser responseJsonParser; private final RateLimitSettings rateLimitSettings; + private final int batchSize; public CustomServiceSettings( TextEmbeddingSettings textEmbeddingSettings, @@ -205,6 +212,19 @@ public CustomServiceSettings( String requestContentString, CustomResponseParser responseJsonParser, @Nullable RateLimitSettings rateLimitSettings + ) { + this(textEmbeddingSettings, url, headers, queryParameters, requestContentString, responseJsonParser, rateLimitSettings, null); + } + + public CustomServiceSettings( + TextEmbeddingSettings textEmbeddingSettings, + String url, + @Nullable Map headers, + @Nullable QueryParameters queryParameters, + String requestContentString, + CustomResponseParser responseJsonParser, + @Nullable RateLimitSettings rateLimitSettings, + @Nullable Integer batchSize ) { this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings); this.url = Objects.requireNonNull(url); @@ -213,6 +233,7 @@ public CustomServiceSettings( this.requestContentString = Objects.requireNonNull(requestContentString); this.responseJsonParser = Objects.requireNonNull(responseJsonParser); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE); } public CustomServiceSettings(StreamInput in) throws IOException { @@ -223,12 +244,20 @@ public CustomServiceSettings(StreamInput in) throws IOException { requestContentString = in.readString(); responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); + if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING) && in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) { // Read the error parsing fields for backwards compatibility in.readString(); in.readString(); } + + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) { + batchSize = in.readVInt(); + } else { + batchSize = DEFAULT_EMBEDDING_BATCH_SIZE; + } } @Override @@ -276,6 +305,10 @@ public CustomResponseParser getResponseJsonParser() { return responseJsonParser; } + public int getBatchSize() { + return batchSize; + } + @Override public RateLimitSettings rateLimitSettings() { return rateLimitSettings; @@ -321,6 +354,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder rateLimitSettings.toXContent(builder, params); + builder.field(BATCH_SIZE, batchSize); + return builder; } @@ -343,12 +378,18 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(requestContentString); out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); + if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING) && out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) { // Write empty strings for backwards compatibility for the error parsing fields out.writeString(""); out.writeString(""); } + + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) { + out.writeVInt(batchSize); + } } @Override @@ -362,7 +403,8 @@ public boolean equals(Object o) { && Objects.equals(queryParameters, that.queryParameters) && Objects.equals(requestContentString, that.requestContentString) && Objects.equals(responseJsonParser, that.responseJsonParser) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && Objects.equals(batchSize, that.batchSize); } @Override @@ -374,7 +416,8 @@ public int hashCode() { queryParameters, requestContentString, responseJsonParser, - rateLimitSettings + rateLimitSettings, + batchSize ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 3e2289e418f76..47346fc896baa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -131,7 +131,9 @@ public void testFromMap() { Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ) ) - ) + ), + CustomServiceSettings.BATCH_SIZE, + 11 ) ), ConfigurationParseContext.REQUEST, @@ -154,7 +156,8 @@ public void testFromMap() { new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))), requestContentString, responseParser, - new RateLimitSettings(10_000) + new RateLimitSettings(10_000), + 11 ) ) ); @@ -579,7 +582,46 @@ public void testXContent() throws IOException { }, "rate_limit": { "requests_per_minute": 10000 - } + }, + "batch_size": 10 + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_BatchSize11() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + null, + 11 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": "string", + "response": { + "json_parser": { + "text_embeddings": "$.result.embeddings[*].embedding" + } + }, + "rate_limit": { + "requests_per_minute": 10000 + }, + "batch_size": 11 } """); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index dedc0d0e71ac9..dc82d71df6503 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -13,6 +13,9 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -23,9 +26,11 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; @@ -37,6 +42,8 @@ import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; import java.io.IOException; import java.util.EnumSet; @@ -45,6 +52,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_DOCUMENT_TEXT; @@ -52,6 +60,7 @@ import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE; import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH; import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -237,7 +246,7 @@ private static CustomModel createInternalEmbeddingModel( url, Map.of("key", "value"), QueryParameters.EMPTY, - "\"input\":\"${input}\"", + "{\"input\":${input}}", parser, new RateLimitSettings(10_000) ), @@ -246,6 +255,35 @@ private static CustomModel createInternalEmbeddingModel( ); } + private static CustomModel createInternalEmbeddingModel( + @Nullable SimilarityMeasure similarityMeasure, + TextEmbeddingResponseParser parser, + String url, + @Nullable ChunkingSettings chunkingSettings, + @Nullable Integer batchSize + ) { + var inferenceId = "inference_id"; + + return new CustomModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + CustomService.NAME, + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, DenseVectorFieldMapper.ElementType.FLOAT), + url, + Map.of("key", "value"), + QueryParameters.EMPTY, + "{\"input\":${input}}", + parser, + new RateLimitSettings(10_000), + batchSize + ), + new CustomTaskSettings(Map.of("key", "test_value")), + new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))), + chunkingSettings + ); + } + private static CustomModel createCustomModel(TaskType taskType, CustomResponseParser customResponseParser, String url) { return new CustomModel( "model_id", @@ -256,7 +294,7 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa url, Map.of("key", "value"), QueryParameters.EMPTY, - "\"input\":\"${input}\"", + "{\"input\":${input}}", customResponseParser, new RateLimitSettings(10_000) ), @@ -572,4 +610,151 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx ); } } + + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { + var model = createInternalEmbeddingModel( + SimilarityMeasure.DOT_PRODUCT, + new TextEmbeddingResponseParser("$.data[*].embedding"), + getUrl(webServer), + ChunkingSettingsTests.createRandomChunkingSettings(), + 2 + ); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.223, + -0.223 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + try (var service = createService(threadPool, clientManager)) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertArrayEquals( + new float[] { 0.123f, -0.123f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + 0.0f + ); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertArrayEquals( + new float[] { 0.223f, -0.223f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + 0.0f + ); + } + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(1)); + assertThat(requestMap.get("input"), is(List.of("a", "bb"))); + } + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + try (var service = createService(threadPool, clientManager)) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("a")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(1)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertArrayEquals( + new float[] { 0.123f, -0.123f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + 0.0f + ); + } + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(1)); + assertThat(requestMap.get("input"), is(List.of("a"))); + } + } }