From 9e7cb3f971b89cacc5e267a889780b2dbd8b12ee Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 17 Jun 2025 10:23:14 -0400 Subject: [PATCH 1/2] Adding chunking tests --- .../org/elasticsearch/TransportVersions.java | 2 + .../services/custom/CustomModel.java | 39 ++++ .../services/custom/CustomService.java | 66 +++++- .../custom/CustomServiceSettings.java | 60 +++++- .../custom/CustomServiceSettingsTests.java | 52 ++++- .../services/custom/CustomServiceTests.java | 190 +++++++++++++++++- 6 files changed, 394 insertions(+), 15 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 8bf8a94fccfe0..70f591b1424ed 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -196,6 +196,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48); public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49); public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_51); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -298,6 +299,7 @@ static TransportVersion def(int id) { public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00); public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00); public static final TransportVersion PROJECT_DELETION_GLOBAL_BLOCK = def(9_098_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_099_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java index 7c00b0a242f94..b23f515055b9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.custom; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -51,6 +52,27 @@ public CustomModel( ); } + public CustomModel( + String inferenceId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + @Nullable ChunkingSettings chunkingSettings, + ConfigurationParseContext context + ) { + this( + inferenceId, + taskType, + service, + CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId), + CustomTaskSettings.fromMap(taskSettings), + CustomSecretSettings.fromMap(secrets), + chunkingSettings + ); + } + // should only be used for testing CustomModel( String inferenceId, @@ -67,6 +89,23 @@ public CustomModel( ); } + // should only be used for testing + CustomModel( + String inferenceId, + TaskType taskType, + String service, + CustomServiceSettings serviceSettings, + CustomTaskSettings taskSettings, + @Nullable CustomSecretSettings secretSettings, + @Nullable ChunkingSettings chunkingSettings + ) { + this( + new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + serviceSettings + ); + } + protected CustomModel(CustomModel model, TaskSettings taskSettings) { super(model, taskSettings); rateLimitServiceSettings = model.rateLimitServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 5e9aef099f622..69e4bed34dcff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -27,6 +28,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -45,6 +48,7 @@ import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -81,12 +85,15 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + var chunkingSettings = extractChunkingSettings(config, taskType); + CustomModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, serviceSettingsMap, + chunkingSettings, ConfigurationParseContext.REQUEST ); @@ -100,6 +107,14 @@ public void parseRequestConfig( } } + private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) { + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return null; + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); @@ -125,7 +140,8 @@ private static CustomModel createModelWithoutLoggingDeprecations( TaskType taskType, Map serviceSettings, Map taskSettings, - @Nullable Map secretSettings + @Nullable Map secretSettings, + @Nullable ChunkingSettings chunkingSettings ) { return createModel( inferenceEntityId, @@ -133,6 +149,7 @@ private static CustomModel createModelWithoutLoggingDeprecations( serviceSettings, taskSettings, secretSettings, + chunkingSettings, ConfigurationParseContext.PERSISTENT ); } @@ -143,12 +160,13 @@ private static CustomModel createModel( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, + @Nullable ChunkingSettings chunkingSettings, ConfigurationParseContext context ) { if (supportedTaskTypes.contains(taskType) == false) { throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); } - return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context); } @Override @@ -162,7 +180,16 @@ public CustomModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap); + var chunkingSettings = extractChunkingSettings(config, taskType); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + chunkingSettings + ); } @Override @@ -170,7 +197,16 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null); + var chunkingSettings = extractChunkingSettings(config, taskType); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + chunkingSettings + ); } @Override @@ -211,7 +247,27 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME)); + if (model instanceof CustomModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var customModel = (CustomModel) model; + var overriddenModel = CustomModel.of(customModel, taskSettings); + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME); + var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + customModel.getServiceSettings().getBatchSize(), + customModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 0d5129b6c759c..c552df1787e88 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; @@ -53,16 +54,18 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings { + public static final String NAME = "custom_service_settings"; public static final String URL = "url"; + public static final String BATCH_SIZE = "batch_size"; public static final String HEADERS = "headers"; public static final String REQUEST = "request"; public static final String RESPONSE = "response"; public static final String JSON_PARSER = "json_parser"; public static final String ERROR_PARSER = "error_parser"; - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); + private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 1; public static CustomServiceSettings fromMap( Map map, @@ -117,6 +120,8 @@ public static CustomServiceSettings fromMap( context ); + var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException); + if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) { throw validationException; } @@ -137,7 +142,8 @@ public static CustomServiceSettings fromMap( requestContentString, responseJsonParser, rateLimitSettings, - errorParser + errorParser, + batchSize ); } @@ -155,7 +161,6 @@ public record TextEmbeddingSettings( null, DenseVectorFieldMapper.ElementType.FLOAT ); - // This refers to settings that are not related to the text embedding task type (all the settings should be null) public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null); @@ -210,6 +215,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private final CustomResponseParser responseJsonParser; private final RateLimitSettings rateLimitSettings; private final ErrorResponseParser errorParser; + private final int batchSize; public CustomServiceSettings( TextEmbeddingSettings textEmbeddingSettings, @@ -220,6 +226,30 @@ public CustomServiceSettings( CustomResponseParser responseJsonParser, @Nullable RateLimitSettings rateLimitSettings, ErrorResponseParser errorParser + ) { + this( + textEmbeddingSettings, + url, + headers, + queryParameters, + requestContentString, + responseJsonParser, + rateLimitSettings, + errorParser, + null + ); + } + + public CustomServiceSettings( + TextEmbeddingSettings textEmbeddingSettings, + String url, + @Nullable Map headers, + @Nullable QueryParameters queryParameters, + String requestContentString, + CustomResponseParser responseJsonParser, + @Nullable RateLimitSettings rateLimitSettings, + ErrorResponseParser errorParser, + @Nullable Integer batchSize ) { this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings); this.url = Objects.requireNonNull(url); @@ -229,6 +259,7 @@ public CustomServiceSettings( this.responseJsonParser = Objects.requireNonNull(responseJsonParser); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); this.errorParser = Objects.requireNonNull(errorParser); + this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE); } public CustomServiceSettings(StreamInput in) throws IOException { @@ -240,6 +271,12 @@ public CustomServiceSettings(StreamInput in) throws IOException { responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); errorParser = new ErrorResponseParser(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) { + batchSize = in.readVInt(); + } else { + batchSize = DEFAULT_EMBEDDING_BATCH_SIZE; + } } @Override @@ -291,6 +328,10 @@ public ErrorResponseParser getErrorParser() { return errorParser; } + public int getBatchSize() { + return batchSize; + } + @Override public RateLimitSettings rateLimitSettings() { return rateLimitSettings; @@ -337,6 +378,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder rateLimitSettings.toXContent(builder, params); + builder.field(BATCH_SIZE, batchSize); + return builder; } @@ -360,6 +403,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); errorParser.writeTo(out); + + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) { + out.writeVInt(batchSize); + } } @Override @@ -374,7 +422,8 @@ public boolean equals(Object o) { && Objects.equals(requestContentString, that.requestContentString) && Objects.equals(responseJsonParser, that.responseJsonParser) && Objects.equals(rateLimitSettings, that.rateLimitSettings) - && Objects.equals(errorParser, that.errorParser); + && Objects.equals(errorParser, that.errorParser) + && Objects.equals(batchSize, that.batchSize); } @Override @@ -387,7 +436,8 @@ public int hashCode() { requestContentString, responseJsonParser, rateLimitSettings, - errorParser + errorParser, + batchSize ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 9e1d3a8f4c8f8..fe801b7322d73 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -137,7 +137,9 @@ public void testFromMap() { CustomServiceSettings.ERROR_PARSER, new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) ) - ) + ), + CustomServiceSettings.BATCH_SIZE, + 10 ) ), ConfigurationParseContext.REQUEST, @@ -161,7 +163,8 @@ public void testFromMap() { requestContentString, responseParser, new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", "inference_id") + new ErrorResponseParser("$.error.message", "inference_id"), + 10 ) ) ); @@ -652,7 +655,50 @@ public void testXContent() throws IOException { }, "rate_limit": { "requests_per_minute": 10000 - } + }, + "batch_size": 1 + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testXContent_BatchSize10() throws IOException { + var entity = new CustomServiceSettings( + CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS, + "http://www.abc.com", + Map.of("key", "value"), + null, + "string", + new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), + null, + new ErrorResponseParser("$.error.message", "inference_id"), + 10 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "url": "http://www.abc.com", + "headers": { + "key": "value" + }, + "request": "string", + "response": { + "json_parser": { + "text_embeddings": "$.result.embeddings[*].embedding" + }, + "error_parser": { + "path": "$.error.message" + } + }, + "rate_limit": { + "requests_per_minute": 10000 + }, + "batch_size": 10 } """); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index f8d650144693d..7f256e1f29c10 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -11,6 +11,9 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -21,9 +24,11 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; @@ -36,6 +41,8 @@ import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; import java.io.IOException; import java.util.EnumSet; @@ -44,6 +51,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_DOCUMENT_TEXT; @@ -51,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE; import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH; import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -243,7 +252,7 @@ private static CustomModel createInternalEmbeddingModel( url, Map.of("key", "value"), QueryParameters.EMPTY, - "\"input\":\"${input}\"", + "{\"input\":${input}}", parser, new RateLimitSettings(10_000), new ErrorResponseParser("$.error.message", inferenceId) @@ -253,6 +262,36 @@ private static CustomModel createInternalEmbeddingModel( ); } + private static CustomModel createInternalEmbeddingModel( + @Nullable SimilarityMeasure similarityMeasure, + TextEmbeddingResponseParser parser, + String url, + @Nullable ChunkingSettings chunkingSettings, + @Nullable Integer batchSize + ) { + var inferenceId = "inference_id"; + + return new CustomModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + CustomService.NAME, + new CustomServiceSettings( + new CustomServiceSettings.TextEmbeddingSettings(similarityMeasure, 123, 456, DenseVectorFieldMapper.ElementType.FLOAT), + url, + Map.of("key", "value"), + QueryParameters.EMPTY, + "{\"input\":${input}}", + parser, + new RateLimitSettings(10_000), + new ErrorResponseParser("$.error.message", inferenceId), + batchSize + ), + new CustomTaskSettings(Map.of("key", "test_value")), + new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))), + chunkingSettings + ); + } + private static CustomModel createCustomModel(TaskType taskType, CustomResponseParser customResponseParser, String url) { var inferenceId = "inference_id"; @@ -265,7 +304,7 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa url, Map.of("key", "value"), QueryParameters.EMPTY, - "\"input\":\"${input}\"", + "{\"input\":${input}}", customResponseParser, new RateLimitSettings(10_000), new ErrorResponseParser("$.error.message", inferenceId) @@ -546,4 +585,151 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx ); } } + + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { + var model = createInternalEmbeddingModel( + SimilarityMeasure.DOT_PRODUCT, + new TextEmbeddingResponseParser("$.data[*].embedding"), + getUrl(webServer), + ChunkingSettingsTests.createRandomChunkingSettings(), + 2 + ); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.223, + -0.223 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + try (var service = createService(threadPool, clientManager)) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertArrayEquals( + new float[] { 0.123f, -0.123f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + 0.0f + ); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, 2), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertArrayEquals( + new float[] { 0.223f, -0.223f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + 0.0f + ); + } + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(1)); + assertThat(requestMap.get("input"), is(List.of("a", "bb"))); + } + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + try (var service = createService(threadPool, clientManager)) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("a")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(1)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, 1), floatResult.chunks().get(0).offset()); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertArrayEquals( + new float[] { 0.123f, -0.123f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values(), + 0.0f + ); + } + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(1)); + assertThat(requestMap.get("input"), is(List.of("a"))); + } + } } From 638d65872d0080be7eb1cf946e9715590654771e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 18 Jun 2025 13:10:24 -0400 Subject: [PATCH 2/2] adjusting default batch size --- .../inference/services/custom/CustomServiceSettingsTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index bc9faafdd0021..47346fc896baa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -583,7 +583,7 @@ public void testXContent() throws IOException { "rate_limit": { "requests_per_minute": 10000 }, - "batch_size": 1 + "batch_size": 10 } """);