From f054dca0b3d83eb36d6feed9fbcce331e4633852 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 13:19:47 +0200 Subject: [PATCH 01/23] Add working dense text embeddings integration with default endpoint. Some tests WIP --- .../org/elasticsearch/TransportVersions.java | 2 + ...etModelsWithElasticInferenceServiceIT.java | 3 +- .../inference/InferenceGetServicesIT.java | 3 +- ...icInferenceServiceAuthorizationServer.java | 4 + .../InferenceRevokeDefaultEndpointsIT.java | 34 +- ...enceServiceDenseTextEmbeddingsRequest.java | 93 +++++ ...rviceDenseTextEmbeddingsRequestEntity.java | 57 +++ ...nceServiceAuthorizationResponseEntity.java | 4 +- ...viceDenseTextEmbeddingsResponseEntity.java | 107 ++++++ ...ServiceSparseEmbeddingsResponseEntity.java | 10 +- .../elastic/ElasticInferenceService.java | 151 +++++++- .../elastic/ElasticInferenceServiceModel.java | 17 +- ...ElasticInferenceServiceRequestManager.java | 4 +- .../ElasticInferenceServiceActionCreator.java | 35 ++ .../ElasticInferenceServiceActionVisitor.java | 2 + ...erenceServiceDenseTextEmbeddingsModel.java | 114 ++++++ ...iceDenseTextEmbeddingsServiceSettings.java | 263 +++++++++++++ .../elastic/ElasticInferenceServiceTests.java | 360 +++++++++--------- ...eServiceDenseTextEmbeddingsModelTests.java | 43 +++ 19 files changed, 1092 insertions(+), 214 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 00943d04275dd..67e69c0136074 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -157,6 +157,7 @@ static TransportVersion def(int id) { public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14); public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); + public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_17); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -214,6 +215,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00); public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00); public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0); + public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_048_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 42289c50864e6..5c95150f1a885 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -26,7 +26,7 @@ public void testGetDefaultEndpoints() throws IOException { var allModels = getAllModels(); var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); - assertThat(allModels, hasSize(5)); + assertThat(allModels, hasSize(6)); assertThat(chatCompletionModels, hasSize(1)); for (var model : chatCompletionModels) { @@ -35,6 +35,7 @@ public void testGetDefaultEndpoints() throws IOException { assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING); + assertInferenceIdTaskType(allModels, ".multilingual-embed-elastic", TaskType.TEXT_EMBEDDING); } private static void assertInferenceIdTaskType(List> models, String inferenceId, TaskType taskType) { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 6f9a550481049..f180b995eb8c7 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ 
b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -64,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(15)); + assertThat(services.size(), equalTo(16)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -79,6 +79,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "elastic", "elasticsearch", "googleaistudio", "googlevertexai", diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 3ea011c1317cc..a032a78a942e0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -36,6 +36,10 @@ public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowS { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed", + "task_types": ["embed/text/dense"] } ] } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 86c1b549d9de5..6fabfdcae4844 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; @@ -197,6 +198,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed", + "task_types": ["embed/text/dense"] } ] } @@ -221,16 +226,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA ".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ), + service ) ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)) + ); PlainActionFuture> listener = new PlainActionFuture<>(); 
service.defaultConfigs(listener); assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat( + listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), + is(".multilingual-embed-elastic") + ); var getModelListener = new PlainActionFuture(); // persists the default endpoints @@ -267,6 +289,16 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA ".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ), + service ) ) ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java new file mode 100644 index 0000000000000..8c43f0e12530c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; + +public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest { + + private final URI uri; + private final ElasticInferenceServiceDenseTextEmbeddingsModel model; + private final List inputs; + private final TraceContextHandler traceContextHandler; + private final InputType inputType; + + public ElasticInferenceServiceDenseTextEmbeddingsRequest( + ElasticInferenceServiceDenseTextEmbeddingsModel model, + List inputs, + TraceContext traceContext, + ElasticInferenceServiceRequestMetadata metadata, + InputType inputType + ) { + super(metadata); + this.inputs = inputs; + this.model = Objects.requireNonNull(model); + this.uri = model.uri(); + this.traceContextHandler = new TraceContextHandler(traceContext); + this.inputType = inputType; + } + + @Override + public HttpRequestBase createHttpRequestBase() { + var httpPost = new HttpPost(uri); + var usageContext = inputTypeToUsageContext(inputType); + 
var requestEntity = Strings.toString( + new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext) + ); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + traceContextHandler.propagateTraceContext(httpPost); + httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + return httpPost; + } + + public TraceContext getTraceContext() { + return traceContextHandler.traceContext(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..99ea5e8d6d3e7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List inputs, + String modelId, + @Nullable ElasticInferenceServiceUsageContext usageContext +) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + private static final String USAGE_CONTEXT = "usage_context"; + + public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity { + Objects.requireNonNull(inputs); + Objects.requireNonNull(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(INPUT_FIELD); + + for (String input : inputs) { + builder.value(input); + } + + builder.endArray(); + + builder.field(MODEL_FIELD, modelId); + + // optional field + if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) { + builder.field(USAGE_CONTEXT, usageContext); + } + + builder.endObject(); + + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java index 30c460d5d1ede..3891850ce122f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer "embed/text/sparse", TaskType.SPARSE_EMBEDDING, "chat", - TaskType.CHAT_COMPLETION + TaskType.CHAT_COMPLETION, + "embed/text/dense", + TaskType.TEXT_EMBEDDING ); @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..d80b9c74a6c6c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in Elastic Inference Service dense text embeddings response"; + + /** + * Parses the Elastic Inference Service Dense Text Embeddings response. + * + * For a request like: + * + *
+     *     
+     *         {
+     *             "input": ["Embed this text", "Embed this text, too"]
+     *         }
+     *     
+     * 
+ * + * The response would look like: + * + *
+     *     
+     *         {
+     *             "data": [
+     *                  [
+     *                      2.1259406,
+     *                      1.7073475,
+     *                      0.9020516
+     *                  ],
+     *                  (...)
+     *             ],
+     *             "meta": {
+     *                  "usage": {...}
+     *             }
+     *         }
+     *     
+     * 
+ */ + + public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + + List parsedEmbeddings = parseList( + jsonParser, + (parser, index) -> ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.parseTextEmbeddingObject(parser) + ); + + if (parsedEmbeddings.isEmpty()) { + return new TextEmbeddingFloatResults(Collections.emptyList()); + } + + return new TextEmbeddingFloatResults(parsedEmbeddings); + } + } + + private static TextEmbeddingFloatResults.Embedding parseTextEmbeddingObject(XContentParser parser) throws IOException { + List embeddingValueList = parseList( + parser, + ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::parseEmbeddingFloatValueList + ); + return TextEmbeddingFloatResults.Embedding.of(embeddingValueList); + } + + private static float parseEmbeddingFloatValueList(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + return parser.floatValue(); + } + + private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java index 
42ca45f75a9c0..af71ccb6dc43f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java @@ -51,11 +51,11 @@ public class ElasticInferenceServiceSparseEmbeddingsResponseEntity { * * { * "data": [ - * { - * "Embed": 2.1259406, - * "this": 1.7073475, - * "text": 0.9020516 - * }, + * [ + * 2.1259406, + * 1.7073475, + * 0.9020516 + * ], * (...) * ], * "meta": { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 75a9e44d25b62..99f7287e785d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,7 +16,9 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -27,6 +29,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import 
org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -36,6 +39,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -51,6 +56,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -70,6 +77,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -79,10 +87,18 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; + public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; - private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); + private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( + TaskType.SPARSE_EMBEDDING, + TaskType.CHAT_COMPLETION, + TaskType.TEXT_EMBEDDING + ); private static final String SERVICE_NAME = "Elastic"; + // TODO: check with team, what makes the most sense + private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 32; + // rainbow-sprinkles static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); @@ -91,10 +107,17 @@ public class ElasticInferenceService extends SenderService { static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2"; static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2); + // multilingual-text-embed + static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed"; + static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); + /** * The task types that the {@link InferenceAction.Request} can accept. 
*/ - private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING); + private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of( + TaskType.SPARSE_EMBEDDING, + TaskType.TEXT_EMBEDDING + ); public static String defaultEndpointId(String modelId) { return Strings.format(".%s-elastic", modelId); @@ -155,6 +178,31 @@ private static Map initDefaultEndpoints( elasticInferenceServiceComponents ), MinimalServiceSettings.sparseEmbedding(NAME) + ), + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + new DefaultModelConfig( + new ElasticInferenceServiceDenseTextEmbeddingsModel( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + null, + null, + false, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ), + MinimalServiceSettings.textEmbedding( + NAME, + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ) ) ); } @@ -270,12 +318,26 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - // Pass-through without actually performing chunking (result will have a single chunk per input) - ActionListener inferListener = listener.delegateFailureAndWrap( - (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response)) - ); + // TODO: we probably want to allow chunked inference for both sparse and dense? 
+ if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } - doInfer(model, inputs, taskSettings, timeout, inferListener); + ElasticInferenceServiceDenseTextEmbeddingsModel elasticInferenceServiceModel = + (ElasticInferenceServiceDenseTextEmbeddingsModel) model; + var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, + elasticInferenceServiceModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } } @Override @@ -294,11 +356,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + ElasticInferenceServiceModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, elasticInferenceServiceComponents, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), @@ -335,6 +405,7 @@ private static ElasticInferenceServiceModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, ElasticInferenceServiceComponents eisServiceComponents, String failureMessage, @@ -361,6 
+432,16 @@ private static ElasticInferenceServiceModel createModel( eisServiceComponents, context ); + case TEXT_EMBEDDING -> new ElasticInferenceServiceDenseTextEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + eisServiceComponents, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -376,11 +457,17 @@ public Model parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -391,11 +478,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -411,6 +504,7 @@ private ElasticInferenceServiceModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage 
) { @@ -419,6 +513,7 @@ private ElasticInferenceServiceModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, elasticInferenceServiceComponents, failureMessage, @@ -432,6 +527,36 @@ public void checkModelConfig(Model model, ActionListener listener) { ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); } + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var modelId = serviceSettings.modelId(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? defaultDenseTextEmbeddingsSimilarity() : similarityFromModel; + var maxInputTokens = serviceSettings.maxInputTokens(); + var dimensionsSetByUser = serviceSettings.dimensionsSetByUser(); + + var updateServiceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarityToUse, + embeddingSize, + maxInputTokens, + dimensionsSetByUser, + serviceSettings.rateLimitSettings() + ); + + return new ElasticInferenceServiceDenseTextEmbeddingsModel(embeddingsModel, updateServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { + // TODO: double-check + return SimilarityMeasure.COSINE; + } + private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); @@ -469,9 +594,9 @@ private LazyInitializable initC configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, 
TaskType.CHAT_COMPLETION)).setDescription( - "The name of the model to use for the inference task." - ) + new SettingsConfiguration.Builder( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ).setDescription("The name of the model to use for the inference task.") .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -482,7 +607,7 @@ private LazyInitializable initC configurationMap.put( MAX_INPUT_TOKENS, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription( "Allows you to specify the maximum number of tokens per input." ) .setLabel("Maximum Input Tokens") @@ -494,7 +619,9 @@ private LazyInitializable initC ); configurationMap.putAll( - RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)) + RateLimitSettings.toSettingsConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ) ); return new InferenceServiceConfiguration.Builder().setService(NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index e03cc36e62417..34a8086119150 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -7,14 +7,15 @@ package org.elasticsearch.xpack.inference.services.elastic; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import 
org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Objects; -public abstract class ElasticInferenceServiceModel extends Model { +public abstract class ElasticInferenceServiceModel extends RateLimitGroupingModel { private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings; @@ -35,12 +36,18 @@ public ElasticInferenceServiceModel( public ElasticInferenceServiceModel(ElasticInferenceServiceModel model, ServiceSettings serviceSettings) { super(model, serviceSettings); - this.rateLimitServiceSettings = model.rateLimitServiceSettings(); + this.rateLimitServiceSettings = model.rateLimitServiceSettings; this.elasticInferenceServiceComponents = model.elasticInferenceServiceComponents(); } - public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings() { - return rateLimitServiceSettings; + @Override + public int rateLimitGroupingHash() { + // The hash is derived from the model ID only, so models that share a model ID share a rate limit group + return Objects.hash(this.getServiceSettings().modelId()); + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); } public ElasticInferenceServiceComponents elasticInferenceServiceComponents() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java index 81231b83c767e..bcb7d395b7e31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java @@ -20,7 +20,7 @@ public abstract class 
ElasticInferenceServiceRequestManager extends BaseRequestM private final ElasticInferenceServiceRequestMetadata requestMetadata; protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) { - super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings()); this.requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); } @@ -32,7 +32,7 @@ record RateLimitGrouping(int modelIdHash) { public static RateLimitGrouping of(ElasticInferenceServiceModel model) { Objects.requireNonNull(model); - return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode()); + return new RateLimitGrouping(model.rateLimitGroupingHash()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 5bdae8582f371..18080d9c055ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -9,9 +9,16 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import 
org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceDenseTextEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -19,10 +26,16 @@ import java.util.Objects; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor { + public static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler( + "elastic dense text embedding", + ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse + ); + private final Sender sender; private final ServiceComponents serviceComponents; @@ -43,4 +56,26 @@ public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel mode ); return new SenderExecutableAction(sender, requestManager, errorMessage); } + + @Override + public ExecutableAction 
create(ElasticInferenceServiceDenseTextEmbeddingsModel model) { + var threadPool = serviceComponents.threadPool(); + + var manager = new GenericRequestManager<>( + threadPool, + model, + DENSE_TEXT_EMBEDDINGS_HANDLER, + (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( + model, + embeddingsInput.getStringInputs(), + traceContext, + extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), + embeddingsInput.getInputType() + ), + EmbeddingsInput.class + ); + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings"); + return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java index 2fdd90b8169bd..c19eeb0f37961 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java @@ -8,10 +8,12 @@ package org.elasticsearch.xpack.inference.services.elastic.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; public interface ElasticInferenceServiceActionVisitor { ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model); + ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java new file mode 100644 index 0000000000000..8c2066ba7046e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel; +import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +public class 
ElasticInferenceServiceDenseTextEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel { + + private final URI uri; + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(serviceSettings, context), + // TODO: we probably want dense embeddings task settings + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ); + } + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings, + // TODO: we probably want dense embeddings task settings + @Nullable TaskSettings taskSettings, + @Nullable SecretSettings secretSettings, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings, + elasticInferenceServiceComponents + ); + this.uri = createUri(); + } + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + ElasticInferenceServiceDenseTextEmbeddingsModel model, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings + ) { + super(model, serviceSettings); + this.uri = createUri(); + } + + @Override + public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { + return visitor.create(this); + } + + @Override + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings getServiceSettings() { + return (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) 
super.getServiceSettings(); + } + + public URI uri() { + return uri; + } + + private URI createUri() throws ElasticsearchStatusException { + try { + // TODO, consider transforming the base URL into a URI for better error handling. + return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/dense"); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + "Failed to create URI for service [" + + this.getConfigurations().getService() + + "] with taskType [" + + this.getTaskType() + + "]: " + + e.getMessage(), + RestStatus.BAD_REQUEST, + e + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..93a5be16aeac0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -0,0 +1,263 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.*; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + ElasticInferenceServiceRateLimitServiceSettings { + + public static final String NAME = "elastic_inference_service_dense_embeddings_service_settings"; + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new 
RateLimitSettings(10_000); + + private final String modelId; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + private final boolean dimensionsSetByUser; + private final RateLimitSettings rateLimitSettings; + + public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromMap( + Map map, + ConfigurationParseContext context + ) { + return switch (context) { + case REQUEST -> fromRequestMap(map, context); + case PERSISTENT -> fromPersistentMap(map, context); + }; + } + + private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromRequestMap( + Map map, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var dimensionsSetByUser = dims != null; + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dims, + maxInputTokens, + dimensionsSetByUser, + rateLimitSettings + ); + } + + private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPersistentMap( + Map map, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + 
RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); + + if (dimensionsSetByUser == null) { + dimensionsSetByUser = Boolean.FALSE; + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dims, + maxInputTokens, + dimensionsSetByUser, + rateLimitSettings + ); + } + + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + String modelId, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + boolean dimensionsSetByUser, + RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + this.dimensionsSetByUser = dimensionsSetByUser; + this.rateLimitSettings = rateLimitSettings; + } + + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + this.dimensionsSetByUser = in.readBoolean(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + 
} + + @Override + public String modelId() { + return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public Boolean dimensionsSetByUser() { + return dimensionsSetByUser; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + /** Returns the same value as {@link #rateLimitSettings()}. */ + public RateLimitSettings getRateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + + toXContentFragmentOfExposedFields(builder, params); + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + + builder.endObject(); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED; + } + + /** + * Writes every field in the order the {@code StreamInput} constructor of this class reads them: model id, similarity, + * dimensions, max input tokens, the dimensions-set-by-user flag and finally the rate limit settings. + * Keep this method and that constructor in sync; a missing or reordered field corrupts the stream for the reader. + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + out.writeBoolean(dimensionsSetByUser); + rateLimitSettings.writeTo(out); + } + + /** + * Two settings instances are equal when all of their fields match, including the model id and the rate limit settings. + */ + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = 
(ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o; + return dimensionsSetByUser == that.dimensionsSetByUser + && Objects.equals(modelId, that.modelId) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + /** Hashes the same fields that {@link #equals(Object)} compares. */ + @Override + public int hashCode() { + return Objects.hash(modelId, similarity, dimensions, maxInputTokens, dimensionsSetByUser, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6a93b1cc19c87..5b65d96cfbfee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; @@ -38,11 +39,9 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import 
org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -59,6 +58,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -85,6 +86,7 @@ import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -394,47 +396,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException 
verifyNoMoreInteractions(sender); } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { - var sender = mock(Sender.class); - - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING); - - try (var service = createService(factory)) { - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - mockModel, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat( - thrownException.getMessage(), - is( - "Inference entity [model_id] does not support task type [text_embedding] " - + "for inference, the task type must be one of [sparse_embedding]." - ) - ); - - verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); - } - - verify(sender, times(1)).close(); - verifyNoMoreInteractions(factory); - verifyNoMoreInteractions(sender); - } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { var sender = mock(Sender.class); @@ -463,7 +424,7 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws thrownException.getMessage(), is( "Inference entity [model_id] does not support task type [chat_completion] " - + "for inference, the task type must be one of [sparse_embedding]. " + + "for inference, the task type must be one of [text_embedding, sparse_embedding]. " + "The task type for the inference entity is chat_completion, " + "please use the _inference/chat_completion/model_id/_stream URL." 
) @@ -604,82 +565,6 @@ public void testInfer_PropagatesProductUseCaseHeader() throws IOException { } } - public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var elasticInferenceServiceURL = getUrl(webServer); - - try (var service = createService(senderFactory, elasticInferenceServiceURL)) { - String responseJson = """ - { - "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - // Set up the product use case in the thread context - String productUseCase = "test-product-use-case"; - threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); - - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); - PlainActionFuture> listener = new PlainActionFuture<>(); - - try { - service.chunkedInfer( - model, - null, - List.of(new ChunkInferenceInput("input text")), - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var results = listener.actionGet(TIMEOUT); - - // Verify the response was processed correctly - ChunkedInference inferenceResult = results.getFirst(); - assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class)); - var sparseResult = (ChunkedInferenceEmbedding) inferenceResult; - assertThat( - sparseResult.chunks(), - is( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), - false - ), - new ChunkedInference.TextOffset(0, "input text".length()) - ) - ) - ) - ); - - // Verify the request was sent and contains expected headers - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - var request = webServer.requests().getFirst(); - 
assertNull(request.getUri().getQuery()); - MatcherAssert.assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - - // Check that the product use case header was set correctly - assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); - - // Verify request body - var requestMap = entityAsMap(request.getBody()); - assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest"))); - } finally { - // Clean up the thread context - threadPool.getThreadContext().stashContext(); - } - } - } - public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws IOException { var elasticInferenceServiceURL = getUrl(webServer); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -738,30 +623,45 @@ public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws I } } - public void testChunkedInfer_PassesThrough() throws IOException { + public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var elasticInferenceServiceURL = getUrl(webServer); - try (var service = createService(senderFactory, elasticInferenceServiceURL)) { + try (var service = createService(senderFactory, getUrl(webServer))) { + + // Batching will call the service with 2 inputs String responseJson = """ { "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 + [ + 0.123, + -0.456, + 0.789 + ], + [ + 0.987, + -0.654, + 0.321 + ] + ], + "meta": { + "usage": { + "total_tokens": 10 } - ] + } } """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + + String productUseCase = "test-product-use-case"; + 
threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 inputs service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("input text")), + List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -769,32 +669,123 @@ public void testChunkedInfer_PassesThrough() throws IOException { ); var results = listener.actionGet(TIMEOUT); - assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); - var sparseResult = (ChunkedInferenceEmbedding) results.get(0); - assertThat( - sparseResult.chunks(), - is( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), - false - ), - new ChunkedInference.TextOffset(0, "input text".length()) - ) - ) - ) + assertThat(results, hasSize(2)); + + // Verify the response was processed correctly + ChunkedInference inferenceResult = results.getFirst(); + assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class)); + + // Verify the request was sent and contains expected headers + assertThat(webServer.requests(), hasSize(1)); + var request = webServer.requests().getFirst(); + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + // Check that the product use case header was set correctly + assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + + } finally { + // Clean up the thread context + threadPool.getThreadContext().stashContext(); + } + } + + public void 
testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel( + getUrl(webServer), + "my-dense-model-id", + createRandomChunkingSettings() + ); + + testChunkedInfer_BatchesCalls(model); + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + + testChunkedInfer_BatchesCalls(model); + } + + private void testChunkedInfer_BatchesCalls(ElasticInferenceServiceDenseTextEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = createService(senderFactory, getUrl(webServer))) { + + // Batching will call the service with 2 inputs + String responseJson = """ + { + "data": [ + [ + 0.123, + -0.456, + 0.789 + ], + [ + 0.987, + -0.654, + 0.321 + ] + ], + "meta": { + "usage": { + "total_tokens": 10 + } + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 inputs + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + + // First result + { + assertThat(results.getFirst(), instanceOf(ChunkedInferenceEmbedding.class)); + var denseResult = (ChunkedInferenceEmbedding) results.getFirst(); + assertThat(denseResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset()); + assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + + var embedding = 
(TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f); + } + + // Second result + { + assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); + var denseResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(denseResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset()); + assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + + var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f); + } + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest"))); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("hello world", "dense embedding"), "model", "my-dense-model-id", "usage_context", "ingest")) + ); } } @@ -806,27 +797,6 @@ public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() thr } } - public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new 
ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { - ensureAuthorizationCallFinished(service); - - assertTrue(service.hideFromConfigurationApi()); - } - } - public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception { try ( var service = createServiceWithMockSender( @@ -856,7 +826,7 @@ public void testGetConfiguration() throws Exception { List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION) + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) ) ) ) @@ -869,7 +839,7 @@ public void testGetConfiguration() throws Exception { { "service": "elastic", "name": "Elastic", - "task_types": ["sparse_embedding", "chat_completion"], + "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -878,7 +848,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -887,7 +857,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -896,7 +866,7 @@ public void testGetConfiguration() throws Exception { "sensitive": 
false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -933,7 +903,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -942,7 +912,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -951,7 +921,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -993,7 +963,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO { "service": "elastic", "name": "Elastic", - "task_types": [], + "task_types": ["text_embedding"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1002,7 +972,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding" , "sparse_embedding", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -1011,7 +981,7 @@ public void 
testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding" , "sparse_embedding", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -1020,7 +990,7 @@ public void testGetConfiguration_WithoutSupportedTaskTypes_WhenModelsReturnTaskO "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -1197,6 +1167,10 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed", + "task_types": ["embed/text/dense"] } ] } @@ -1218,6 +1192,16 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), service ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ), + service + ), new InferenceService.DefaultConfigId( ".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), @@ -1226,14 +1210,18 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)) + ); PlainActionFuture> listener = new PlainActionFuture<>(); 
service.defaultConfigs(listener); var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(2)); + assertThat(models.size(), is(3)); assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-elastic")); + assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java new file mode 100644 index 0000000000000..753e5bb17303d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public class ElasticInferenceServiceDenseTextEmbeddingsModelTests { + + public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel( + String url, + String modelId, + ChunkingSettings chunkingSettings + ) { + return new ElasticInferenceServiceDenseTextEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "elastic", + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + SimilarityMeasure.COSINE, + null, + null, + false, + new RateLimitSettings(1000L) + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(url) + ); + } + +} From 6584dabca00fe2acd668d6d1fca952702b2ef7e4 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 14:15:30 +0200 Subject: [PATCH 02/23] Fix merge conflicts, compilation errors and test failures --- .../ElasticInferenceServiceActionCreator.java | 11 +++-- ...enceServiceDenseTextEmbeddingsRequest.java | 4 +- ...rviceDenseTextEmbeddingsRequestEntity.java | 2 +- .../ElasticInferenceServiceRerankRequest.java | 4 +- ...icInferenceServiceRerankRequestEntity.java | 2 +- ...erenceServiceRerankRequestEntityTests.java | 2 +- ...ticInferenceServiceRerankRequestTests.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 42 ------------------- 8 files changed, 12 insertions(+), 57 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/{external/request/elastic => 
services/elastic/request}/ElasticInferenceServiceDenseTextEmbeddingsRequest.java (94%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/{external/request/elastic => services/elastic/request}/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java (96%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/{external/request/elastic/rerank => services/elastic/request}/ElasticInferenceServiceRerankRequest.java (91%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/{external/request/elastic/rerank => services/elastic/request}/ElasticInferenceServiceRerankRequestEntity.java (95%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 90abe9a780187..8b987cd53bc81 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -11,26 +11,25 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import 
org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest; -import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity; -import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceDenseTextEmbeddingsRequest; import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; -import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; import static 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java similarity index 94% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java index 8c43f0e12530c..8a873504ee128 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.elastic; +package org.elasticsearch.xpack.inference.services.elastic.request; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; @@ -25,7 +25,7 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java index 99ea5e8d6d3e7..c149e6cf67063 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.elastic; +package org.elasticsearch.xpack.inference.services.elastic.request; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java similarity index 91% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java index 08b3fd2384642..63b26f2a1223b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.elastic.rerank; +package org.elasticsearch.xpack.inference.services.elastic.request; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; @@ -15,8 +15,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest; -import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequestEntity.java similarity index 95% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequestEntity.java index b542af93047fa..1e21b6f7d8eeb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.elastic.rerank; +package org.elasticsearch.xpack.inference.services.elastic.request; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java index 407d3e38b4da1..a484c690b260c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequestEntity; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequestEntity; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java index 4e6efed6faa59..58a357684961c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java @@ -10,7 +10,7 @@ import org.apache.http.client.methods.HttpPost; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 5eeec40d17f4a..dffcf5f604f55 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -30,7 +30,6 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.http.MockResponse; @@ -424,47 +423,6 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc } } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { - var sender = mock(Sender.class); - - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var mockModel = getInvalidModel("model_id", "service_name", 
TaskType.TEXT_EMBEDDING); - - try (var service = createService(factory)) { - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - mockModel, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat( - thrownException.getMessage(), - is( - "Inference entity [model_id] does not support task type [text_embedding] " - + "for inference, the task type must be one of [sparse_embedding, rerank]." - ) - ); - - verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); - } - - verify(sender, times(1)).close(); - verifyNoMoreInteractions(factory); - verifyNoMoreInteractions(sender); - } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { var sender = mock(Sender.class); From 9d47176e4785b44b16efce52e3acacee637646d7 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 14:16:26 +0200 Subject: [PATCH 03/23] Spotless apply --- .../services/elastic/ElasticInferenceService.java | 8 ++++++-- .../action/ElasticInferenceServiceActionVisitor.java | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 41d81438ee251..e278068ab29bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -56,9 +56,9 @@ import 
org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -115,7 +115,11 @@ public class ElasticInferenceService extends SenderService { /** * The task types that the {@link InferenceAction.Request} can accept. 
*/ - private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING); + private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of( + TaskType.SPARSE_EMBEDDING, + TaskType.RERANK, + TaskType.TEXT_EMBEDDING + ); public static String defaultEndpointId(String modelId) { return Strings.format(".%s-elastic", modelId); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java index eeb0cdc571f45..4f8a9c9ec20a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.inference.services.elastic.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; public interface ElasticInferenceServiceActionVisitor { From 3e8c70a47ab000e38e3e02d91b3a271500447b9c Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 15:12:15 +0200 Subject: [PATCH 04/23] Add ElasticInferenceServiceDenseTextEmbeddingsRequestTests --- ...erviceDenseTextEmbeddingsRequestTests.java | 165 ++++++++++++++++++ 1 file 
changed, 165 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..75a6c8a32b65a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class ElasticInferenceServiceDenseTextEmbeddingsRequestTests extends ESTestCase { + + public void testCreateHttpRequest_UsageContextSearch() throws IOException { + var url = "http://eis-gateway.com"; + var input = List.of("input text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.SEARCH); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.size(), equalTo(3)); + assertThat(requestMap.get("input"), is(input)); + assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("usage_context"), 
equalTo("search")); + } + + public void testCreateHttpRequest_UsageContextIngest() throws IOException { + var url = "http://eis-gateway.com"; + var input = List.of("ingest text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.INGEST); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.size(), equalTo(3)); + assertThat(requestMap.get("input"), is(input)); + assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("usage_context"), equalTo("ingest")); + } + + public void testCreateHttpRequest_UsageContextUnspecified() throws IOException { + var url = "http://eis-gateway.com"; + var input = List.of("unspecified text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.UNSPECIFIED); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("input"), is(input)); + assertThat(requestMap.get("model"), is(modelId)); + // usage_context should not be present for UNSPECIFIED + } + + public void testCreateHttpRequest_MultipleInputs() throws IOException { + var url = "http://eis-gateway.com"; + var inputs = List.of("first input", "second input", "third input"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, inputs, 
InputType.SEARCH); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.size(), equalTo(3)); + assertThat(requestMap.get("input"), is(inputs)); + assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("usage_context"), equalTo("search")); + } + + public void testTraceContextPropagatedThroughHTTPHeaders() { + var url = "http://eis-gateway.com"; + var input = List.of("input text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.UNSPECIFIED); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var traceParent = request.getTraceContext().traceParent(); + var traceState = request.getTraceContext().traceState(); + + assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent)); + assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState)); + } + + public void testTruncate_ReturnsSameInstance() { + var url = "http://eis-gateway.com"; + var input = List.of("input text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.UNSPECIFIED); + var truncatedRequest = request.truncate(); + + // Dense text embeddings request doesn't support truncation, should return same instance + assertThat(truncatedRequest, is(request)); + } + + public void testGetTruncationInfo_ReturnsNull() { + var url = "http://eis-gateway.com"; + var input = List.of("input text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, 
InputType.UNSPECIFIED); + + // Dense text embeddings request doesn't support truncation info + assertThat(request.getTruncationInfo(), is(nullValue())); + } + + private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest( + String url, + String modelId, + List inputs, + InputType inputType + ) { + var embeddingsModel = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId, null); + + return new ElasticInferenceServiceDenseTextEmbeddingsRequest( + embeddingsModel, + inputs, + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomElasticInferenceServiceRequestMetadata(), + inputType + ); + } +} From 23e7595be11714abb7031d380662396c1d7a9199 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 15:16:44 +0200 Subject: [PATCH 05/23] Add ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests --- ...DenseTextEmbeddingsRequestEntityTests.java | 147 ++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..f0ac37174f155 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java @@ -0,0 +1,147 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests extends ESTestCase { + + public void testToXContent_SingleInput_UnspecifiedUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc"), + "my-model-id", + ElasticInferenceServiceUsageContext.UNSPECIFIED + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": ["abc"], + "model": "my-model-id" + }""")); + } + + public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc", "def"), + "my-model-id", + ElasticInferenceServiceUsageContext.UNSPECIFIED + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [ + "abc", + "def" + ], + "model": "my-model-id" + } + """)); + } + + public void testToXContent_SingleInput_SearchUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc"), + "my-model-id", + ElasticInferenceServiceUsageContext.SEARCH + ); + String xContentString = 
xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": ["abc"], + "model": "my-model-id", + "usage_context": "search" + } + """)); + } + + public void testToXContent_SingleInput_IngestUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc"), + "my-model-id", + ElasticInferenceServiceUsageContext.INGEST + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": ["abc"], + "model": "my-model-id", + "usage_context": "ingest" + } + """)); + } + + public void testToXContent_MultipleInputs_SearchUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("first input", "second input", "third input"), + "my-dense-model", + ElasticInferenceServiceUsageContext.SEARCH + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [ + "first input", + "second input", + "third input" + ], + "model": "my-dense-model", + "usage_context": "search" + } + """)); + } + + public void testToXContent_MultipleInputs_IngestUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("document one", "document two"), + "embedding-model-v2", + ElasticInferenceServiceUsageContext.INGEST + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [ + "document one", + "document two" + ], + "model": "embedding-model-v2", + "usage_context": "ingest" + } + """)); + } + + public void testToXContent_EmptyInput_UnspecifiedUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of(""), + "my-model-id", + 
ElasticInferenceServiceUsageContext.UNSPECIFIED + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [""], + "model": "my-model-id" + } + """)); + } + + private String xContentEntityToString(ElasticInferenceServiceDenseTextEmbeddingsRequestEntity entity) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + return Strings.toString(builder); + } +} From 5af751695a8fc31cf7df1316a449da2771085ed7 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 15:19:52 +0200 Subject: [PATCH 06/23] Add "-v1" to multilingual-embed --- .../InferenceGetModelsWithElasticInferenceServiceIT.java | 2 +- .../MockElasticInferenceServiceAuthorizationServer.java | 2 +- .../integration/InferenceRevokeDefaultEndpointsIT.java | 8 ++++---- .../services/elastic/ElasticInferenceService.java | 2 +- .../services/elastic/ElasticInferenceServiceTests.java | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 1f75ffb902428..83141028156e5 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -42,7 +42,7 @@ public void testGetDefaultEndpoints() throws IOException { assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, 
".elser-v2-elastic", TaskType.SPARSE_EMBEDDING); - assertInferenceIdTaskType(allModels, ".multilingual-embed-elastic", TaskType.TEXT_EMBEDDING); + assertInferenceIdTaskType(allModels, ".multilingual-embed-v1-elastic", TaskType.TEXT_EMBEDDING); } private static void assertInferenceIdTaskType(List> models, String inferenceId, TaskType taskType) { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index f239c7de8b0f1..36b7dfeae9850 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -43,7 +43,7 @@ public void enqueueAuthorizeAllModelsResponse() { "task_types": ["embed/text/sparse"] }, { - "model_name": "multilingual-embed", + "model_name": "multilingual-embed-v1", "task_types": ["embed/text/dense"] } ] diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 6fabfdcae4844..be767085263d4 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -200,7 +200,7 @@ public void 
testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA "task_types": ["embed/text/sparse"] }, { - "model_name": "multilingual-embed", + "model_name": "multilingual-embed-v1", "task_types": ["embed/text/dense"] } ] @@ -228,7 +228,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA service ), new InferenceService.DefaultConfigId( - ".multilingual-embed-elastic", + ".multilingual-embed-v1-elastic", MinimalServiceSettings.textEmbedding( ElasticInferenceService.NAME, ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, @@ -251,7 +251,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); assertThat( listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), - is(".multilingual-embed-elastic") + is(".multilingual-embed-v1-elastic") ); var getModelListener = new PlainActionFuture(); @@ -291,7 +291,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA service ), new InferenceService.DefaultConfigId( - ".multilingual-embed-elastic", + ".multilingual-embed-v1-elastic", MinimalServiceSettings.textEmbedding( ElasticInferenceService.NAME, ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index e278068ab29bc..649684adca94d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -109,7 +109,7 @@ public class ElasticInferenceService extends SenderService { static final 
String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2); // multilingual-text-embed - static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed"; + static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); /** diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index dffcf5f604f55..5e7e29c5594a1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1266,7 +1266,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() "task_types": ["embed/text/sparse"] }, { - "model_name": "multilingual-embed", + "model_name": "multilingual-embed-v1", "task_types": ["embed/text/dense"] } ] @@ -1290,7 +1290,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() service ), new InferenceService.DefaultConfigId( - ".multilingual-embed-elastic", + ".multilingual-embed-v1-elastic", MinimalServiceSettings.textEmbedding( ElasticInferenceService.NAME, ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, @@ -1317,7 +1317,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() var models = listener.actionGet(TIMEOUT); assertThat(models.size(), is(3)); assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-elastic")); + 
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic")); assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); } } From fddfd9d01127b2cd2830e38c9b169db79eddb7a4 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 15:33:04 +0200 Subject: [PATCH 07/23] Add ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java --- ...iceDenseTextEmbeddingsServiceSettings.java | 14 +- ...nseTextEmbeddingsServiceSettingsTests.java | 224 ++++++++++++++++++ 2 files changed, 233 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java index 93a5be16aeac0..e08f54467cc45 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -146,7 +146,7 @@ public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; this.dimensionsSetByUser = dimensionsSetByUser; - this.rateLimitSettings = rateLimitSettings; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); } public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput 
in) throws IOException { @@ -239,10 +239,12 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); out.writeBoolean(dimensionsSetByUser); + rateLimitSettings.writeTo(out); } @Override @@ -250,14 +252,16 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o; - return Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) - && Objects.equals(similarity, that.similarity) + return dimensionsSetByUser == that.dimensionsSetByUser + && Objects.equals(modelId, that.modelId) + && similarity == that.similarity && Objects.equals(dimensions, that.dimensions) - && Objects.equals(maxInputTokens, that.maxInputTokens); + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); } @Override public int hashCode() { - return Objects.hash(similarity, dimensions, maxInputTokens, dimensionsSetByUser); + return Objects.hash(modelId, similarity, dimensions, maxInputTokens, dimensionsSetByUser, rateLimitSettings); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..6a9dfb02d13bc --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -0,0 +1,224 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase< + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> { + + @Override + protected Writeable.Reader instanceReader() { + return ElasticInferenceServiceDenseTextEmbeddingsServiceSettings::new; + } + + @Override + protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings mutateInstance( + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings instance + ) throws IOException { + return randomValueOtherThan(instance, 
ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests::createRandom); + } + + public void testFromMap_Request_WithAllSettings() { + var modelId = "my-dense-model-id"; + var similarity = SimilarityMeasure.COSINE; + var dimensions = 384; + var maxInputTokens = 512; + + var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarity.toString(), + ServiceFields.DIMENSIONS, + dimensions, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens + ) + ), + ConfigurationParseContext.REQUEST + ); + + assertThat(serviceSettings.modelId(), is(modelId)); + assertThat(serviceSettings.similarity(), is(similarity)); + assertThat(serviceSettings.dimensions(), is(dimensions)); + assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens)); + assertThat(serviceSettings.dimensionsSetByUser(), is(true)); // dimensions were provided + } + + public void testFromMap_Persistent_WithDimensionsSetByUser() { + var modelId = "my-dense-model-id"; + var similarity = SimilarityMeasure.DOT_PRODUCT; + var dimensions = 768; + var dimensionsSetByUser = true; + + var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarity.toString(), + ServiceFields.DIMENSIONS, + dimensions, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + dimensionsSetByUser + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings.modelId(), is(modelId)); + assertThat(serviceSettings.similarity(), is(similarity)); + assertThat(serviceSettings.dimensions(), is(dimensions)); + assertThat(serviceSettings.dimensionsSetByUser(), is(dimensionsSetByUser)); + } + + public void testFromMap_Persistent_WithoutDimensionsSetByUser_DefaultsToFalse() { + var modelId = "my-dense-model-id"; + + var serviceSettings = 
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings.dimensionsSetByUser(), is(false)); + } + + public void testToXContent_WritesAllFields() throws IOException { + var modelId = "my-dense-model"; + var similarity = SimilarityMeasure.DOT_PRODUCT; + var dimensions = 1024; + var maxInputTokens = 256; + var dimensionsSetByUser = true; + var rateLimitSettings = new RateLimitSettings(5000); + + var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dimensions, + maxInputTokens, + dimensionsSetByUser, + rateLimitSettings + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat( + xContentResult, + is( + Strings.format( + """ + {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d},"dimensions_set_by_user":%s}""", + similarity, + dimensions, + maxInputTokens, + modelId, + rateLimitSettings.requestsPerTimeUnit(), + dimensionsSetByUser + ) + ) + ); + } + + public void testToXContent_WritesOnlyNonNullFields() throws IOException { + var modelId = "my-dense-model"; + var rateLimitSettings = new RateLimitSettings(2000); + + var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + null, // similarity + null, // dimensions + null, // maxInputTokens + false, // dimensionsSetByUser + rateLimitSettings + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat( + xContentResult, + is( + Strings.format( + """ + {"model_id":"%s","rate_limit":{"requests_per_minute":%d},"dimensions_set_by_user":false}""", + 
modelId, + rateLimitSettings.requestsPerTimeUnit() + ) + ) + ); + } + + public void testToXContentFragmentOfExposedFields() throws IOException { + var modelId = "my-dense-model"; + var rateLimitSettings = new RateLimitSettings(1500); + + var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + SimilarityMeasure.COSINE, + 512, + 128, + true, + rateLimitSettings + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + serviceSettings.toXContentFragmentOfExposedFields(builder, null); + builder.endObject(); + String xContentResult = Strings.toString(builder); + + // Only model_id and rate_limit should be in exposed fields + assertThat(xContentResult, is(Strings.format(""" + {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit()))); + } + + public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRandom() { + var modelId = randomAlphaOfLength(10); + var similarity = SimilarityMeasure.COSINE; + var dimensions = randomBoolean() ? randomIntBetween(1, 1024) : null; + var maxInputTokens = randomBoolean() ? randomIntBetween(128, 256) : null; + var dimensionsSetByUser = randomBoolean(); + var rateLimitSettings = randomBoolean() ? 
new RateLimitSettings(randomIntBetween(1, 10000)) : null; + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dimensions, + maxInputTokens, + dimensionsSetByUser, + rateLimitSettings + ); + } +} From 9b48dfb8c4c7f23899e90026d8fd74d8ddee939f Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 15:39:33 +0200 Subject: [PATCH 08/23] Add dense text embedding test cases to ElasticInferenceServiceActionCreatorTests --- ...ticInferenceServiceActionCreatorTests.java | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 49957800f3a83..66e15d71a038c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -22,12 +22,14 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import 
org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.After; @@ -256,6 +258,213 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc } } + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [ + [ + 2.1259406, + 1.7073475, + 0.9020516 + ], + [ + 1.8342123, + 2.3456789, + 0.7654321 + ] + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); + var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(textEmbeddingResults.embeddings(), hasSize(2)); + + var firstEmbedding = textEmbeddingResults.embeddings().get(0); + assertThat(firstEmbedding.values(), is(new float[]{2.1259406f, 1.7073475f, 
0.9020516f})); + + var secondEmbedding = textEmbeddingResults.embeddings().get(1); + assertThat(secondEmbedding.values(), is(new float[]{1.8342123f, 2.3456789f, 0.7654321f})); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world", "second text")); + assertThat(requestMap.get("model"), is("my-dense-model-id")); + } + } + + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_WithUsageContext() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [ + [ + 0.1234567, + 0.9876543 + ] + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); + var textEmbeddingResults = (TextEmbeddingFloatResults) result; + 
assertThat(textEmbeddingResults.embeddings(), hasSize(1)); + + var embedding = textEmbeddingResults.embeddings().get(0); + assertThat(embedding.values(), is(new float[]{0.1234567f, 0.9876543f})); + + assertThat(webServer.requests(), hasSize(1)); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(3)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("search query")); + assertThat(requestMap.get("model"), is("my-dense-model-id")); + assertThat(requestMap.get("usage_context"), is("search")); + } + } + + @SuppressWarnings("unchecked") + public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + // This will fail because the expected output is {"data": [[...]]} + String responseJson = """ + { + "data": { + "embedding": [2.1259406, 1.7073475] + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> 
listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world")); + assertThat(requestMap.get("model"), is("my-dense-model-id")); + } + } + + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_EmptyEmbeddings() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); + var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(textEmbeddingResults.embeddings(), hasSize(0)); + + assertThat(webServer.requests(), 
hasSize(1)); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, hasSize(0)); + } + } + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); From dbdadbee4f1303fddf2756aa0924b2b0356c2084 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 23 Jun 2025 13:47:32 +0000 Subject: [PATCH 09/23] [CI] Auto commit changes from spotless --- .../ElasticInferenceServiceActionCreatorTests.java | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 66e15d71a038c..e63559ffd824d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -302,10 +302,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() assertThat(textEmbeddingResults.embeddings(), hasSize(2)); var firstEmbedding = textEmbeddingResults.embeddings().get(0); - assertThat(firstEmbedding.values(), is(new float[]{2.1259406f, 1.7073475f, 0.9020516f})); + assertThat(firstEmbedding.values(), is(new float[] { 2.1259406f, 1.7073475f, 0.9020516f })); var secondEmbedding = textEmbeddingResults.embeddings().get(1); - assertThat(secondEmbedding.values(), is(new float[]{1.8342123f, 2.3456789f, 0.7654321f})); + assertThat(secondEmbedding.values(), is(new 
float[] { 1.8342123f, 2.3456789f, 0.7654321f })); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -358,7 +358,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W assertThat(textEmbeddingResults.embeddings(), hasSize(1)); var embedding = textEmbeddingResults.embeddings().get(0); - assertThat(embedding.values(), is(new float[]{0.1234567f, 0.9876543f})); + assertThat(embedding.values(), is(new float[] { 0.1234567f, 0.9876543f })); assertThat(webServer.requests(), hasSize(1)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); @@ -445,11 +445,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); From e2f872e44da13b6ea7b2ad912a59dbd72e47c5a7 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 15:53:46 +0200 Subject: [PATCH 10/23] Add ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests --- ...ticInferenceServiceActionCreatorTests.java | 12 +- ...enseTextEmbeddingsResponseEntityTests.java | 124 ++++++++++++++++++ 2 files changed, 128 insertions(+), 8 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 66e15d71a038c..e63559ffd824d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -302,10 +302,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() assertThat(textEmbeddingResults.embeddings(), hasSize(2)); var firstEmbedding = textEmbeddingResults.embeddings().get(0); - assertThat(firstEmbedding.values(), is(new float[]{2.1259406f, 1.7073475f, 0.9020516f})); + assertThat(firstEmbedding.values(), is(new float[] { 2.1259406f, 1.7073475f, 0.9020516f })); var secondEmbedding = textEmbeddingResults.embeddings().get(1); - assertThat(secondEmbedding.values(), is(new float[]{1.8342123f, 2.3456789f, 0.7654321f})); + assertThat(secondEmbedding.values(), is(new float[] { 1.8342123f, 2.3456789f, 0.7654321f })); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); @@ -358,7 +358,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W assertThat(textEmbeddingResults.embeddings(), hasSize(1)); var embedding = textEmbeddingResults.embeddings().get(0); - assertThat(embedding.values(), is(new float[]{0.1234567f, 0.9876543f})); + assertThat(embedding.values(), is(new float[] { 0.1234567f, 0.9876543f })); assertThat(webServer.requests(), hasSize(1)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); @@ -445,11 +445,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(), null, 
InputType.UNSPECIFIED), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..2883a1ab73c21 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.Mockito.mock; + +public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests extends ESTestCase { + + public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throws Exception { + String responseJson = """ + { + "data": [ + [ + 1.23, + 4.56, + 7.89 + ] + ] + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(1)); + + var embedding = parsedResults.embeddings().get(0); + assertThat(embedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f })); + } + + public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() throws Exception { + String responseJson = """ + { + "data": [ + [ + 1.23, + 4.56, + 7.89 + ], + [ + 0.12, + 0.34, + 0.56 + ] + ] + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(2)); + + var firstEmbedding = parsedResults.embeddings().get(0); + 
assertThat(firstEmbedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f })); + + var secondEmbedding = parsedResults.embeddings().get(1); + assertThat(secondEmbedding.values(), is(new float[] { 0.12f, 0.34f, 0.56f })); + } + + public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception { + String responseJson = """ + { + "data": [] + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(0)); + } + + public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta() throws Exception { + String responseJson = """ + { + "data": [ + [ + -1.0, + 0.0, + 1.0 + ] + ], + "meta": { + "usage": { + "total_tokens": 5 + } + } + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(1)); + + var embedding = parsedResults.embeddings().get(0); + assertThat(embedding.values(), is(new float[] { -1.0f, 0.0f, 1.0f })); + } +} From 6a35870ebf1c025ffbe43392c421f186db57e904 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 23 Jun 2025 16:55:56 +0200 Subject: [PATCH 11/23] Fix compilation error after resolving merge conflict and spotlessAppl --- .../InferenceRevokeDefaultEndpointsIT.java | 14 +++++++------- .../services/elastic/ElasticInferenceService.java | 11 +++++------ ...nferenceServiceAuthorizationResponseEntity.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 5 ++++- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 79aeb6cc028b8..1aa1f0532a51e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -239,10 +239,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), DenseVectorFieldMapper.ElementType.FLOAT ), - service - ), -new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", + service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", MinimalServiceSettings.rerank(ElasticInferenceService.NAME), service ) @@ -312,12 +312,12 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), DenseVectorFieldMapper.ElementType.FLOAT ), - service - ), + service + ), new InferenceService.DefaultConfigId( ".rerank-v1-elastic", MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service + service ) ) ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 9acea1138362c..0346cdec38a7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -117,7 +117,6 @@ public class ElasticInferenceService extends SenderService { static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; static final 
String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); - /** * The task types that the {@link InferenceAction.Request} can accept. */ @@ -201,18 +200,18 @@ private static Map initDefaultEndpoints( false, ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS ), - EmptyTaskSettings.INSTANCE, + EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents - ), - MinimalServiceSettings.textEmbedding( + ), + MinimalServiceSettings.textEmbedding( NAME, DENSE_TEXT_EMBEDDINGS_DIMENSIONS, defaultDenseTextEmbeddingsSimilarity(), DenseVectorFieldMapper.ElementType.FLOAT ) - ), - + ), + DEFAULT_RERANK_MODEL_ID_V1, new DefaultModelConfig( new ElasticInferenceServiceRerankModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java index d009b811ed662..451c601e7cc91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -45,7 +45,7 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer "chat", TaskType.CHAT_COMPLETION, "embed/text/dense", - TaskType.TEXT_EMBEDDING + TaskType.TEXT_EMBEDDING, "rerank/text/text-similarity", TaskType.RERANK ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 
df7686f797529..d6318b798238a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1316,7 +1316,10 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) + ); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); From 3b486b7bd9fca5bce5329089d76454e0cf9e9e70 Mon Sep 17 00:00:00 2001 From: Brendan Jugan Date: Mon, 23 Jun 2025 14:07:20 -0400 Subject: [PATCH 12/23] remove dimensions_set_by_user --- .../elastic/ElasticInferenceService.java | 3 -- ...iceDenseTextEmbeddingsServiceSettings.java | 26 +--------- ...eServiceDenseTextEmbeddingsModelTests.java | 1 - ...nseTextEmbeddingsServiceSettingsTests.java | 52 ++----------------- 4 files changed, 5 insertions(+), 77 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 0346cdec38a7c..1a10bd2c1d189 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -197,7 +197,6 @@ private static Map initDefaultEndpoints( defaultDenseTextEmbeddingsSimilarity(), null, null, - false, 
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS ), EmptyTaskSettings.INSTANCE, @@ -572,14 +571,12 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { var similarityFromModel = serviceSettings.similarity(); var similarityToUse = similarityFromModel == null ? defaultDenseTextEmbeddingsSimilarity() : similarityFromModel; var maxInputTokens = serviceSettings.maxInputTokens(); - var dimensionsSetByUser = serviceSettings.dimensionsSetByUser(); var updateServiceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( modelId, similarityToUse, embeddingSize, maxInputTokens, - dimensionsSetByUser, serviceSettings.rateLimitSettings() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java index e08f54467cc45..10505c68cdfb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -39,7 +39,6 @@ public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends F ElasticInferenceServiceRateLimitServiceSettings { public static final String NAME = "elastic_inference_service_dense_embeddings_service_settings"; - static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); @@ -47,7 +46,6 @@ public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends F 
private final SimilarityMeasure similarity; private final Integer dimensions; private final Integer maxInputTokens; - private final boolean dimensionsSetByUser; private final RateLimitSettings rateLimitSettings; public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromMap( @@ -83,14 +81,11 @@ private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromReq throw validationException; } - var dimensionsSetByUser = dims != null; - return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( modelId, similarity, dims, maxInputTokens, - dimensionsSetByUser, rateLimitSettings ); } @@ -113,11 +108,6 @@ private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPer SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); - - if (dimensionsSetByUser == null) { - dimensionsSetByUser = Boolean.FALSE; - } if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -128,7 +118,6 @@ private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPer similarity, dims, maxInputTokens, - dimensionsSetByUser, rateLimitSettings ); } @@ -138,14 +127,12 @@ public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, @Nullable Integer maxInputTokens, - boolean dimensionsSetByUser, RateLimitSettings rateLimitSettings ) { this.modelId = modelId; this.similarity = similarity; this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; - this.dimensionsSetByUser = dimensionsSetByUser; this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); } @@ -154,7 +141,6 @@ public 
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) this.similarity = in.readOptionalEnum(SimilarityMeasure.class); this.dimensions = in.readOptionalVInt(); this.maxInputTokens = in.readOptionalVInt(); - this.dimensionsSetByUser = in.readBoolean(); this.rateLimitSettings = new RateLimitSettings(in); } @@ -182,11 +168,6 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } - @Override - public Boolean dimensionsSetByUser() { - return dimensionsSetByUser; - } - @Override public DenseVectorFieldMapper.ElementType elementType() { return DenseVectorFieldMapper.ElementType.FLOAT; @@ -226,7 +207,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } toXContentFragmentOfExposedFields(builder, params); - builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); builder.endObject(); return builder; @@ -243,7 +223,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); - out.writeBoolean(dimensionsSetByUser); rateLimitSettings.writeTo(out); } @@ -252,8 +231,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o; - return dimensionsSetByUser == that.dimensionsSetByUser - && Objects.equals(modelId, that.modelId) + return Objects.equals(modelId, that.modelId) && similarity == that.similarity && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) @@ -262,6 +240,6 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(modelId, similarity, dimensions, maxInputTokens, dimensionsSetByUser, rateLimitSettings); + return Objects.hash(modelId, 
similarity, dimensions, maxInputTokens, rateLimitSettings); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java index 753e5bb17303d..5719bb94a0be7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java @@ -31,7 +31,6 @@ public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel( SimilarityMeasure.COSINE, null, null, - false, new RateLimitSettings(1000L) ), EmptyTaskSettings.INSTANCE, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java index 6a9dfb02d13bc..4c661d8fed39b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -70,46 +70,6 @@ public void testFromMap_Request_WithAllSettings() { assertThat(serviceSettings.similarity(), is(similarity)); assertThat(serviceSettings.dimensions(), is(dimensions)); 
assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens)); - assertThat(serviceSettings.dimensionsSetByUser(), is(true)); // dimensions were provided - } - - public void testFromMap_Persistent_WithDimensionsSetByUser() { - var modelId = "my-dense-model-id"; - var similarity = SimilarityMeasure.DOT_PRODUCT; - var dimensions = 768; - var dimensionsSetByUser = true; - - var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.MODEL_ID, - modelId, - ServiceFields.SIMILARITY, - similarity.toString(), - ServiceFields.DIMENSIONS, - dimensions, - ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, - dimensionsSetByUser - ) - ), - ConfigurationParseContext.PERSISTENT - ); - - assertThat(serviceSettings.modelId(), is(modelId)); - assertThat(serviceSettings.similarity(), is(similarity)); - assertThat(serviceSettings.dimensions(), is(dimensions)); - assertThat(serviceSettings.dimensionsSetByUser(), is(dimensionsSetByUser)); - } - - public void testFromMap_Persistent_WithoutDimensionsSetByUser_DefaultsToFalse() { - var modelId = "my-dense-model-id"; - - var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), - ConfigurationParseContext.PERSISTENT - ); - - assertThat(serviceSettings.dimensionsSetByUser(), is(false)); } public void testToXContent_WritesAllFields() throws IOException { @@ -117,7 +77,6 @@ public void testToXContent_WritesAllFields() throws IOException { var similarity = SimilarityMeasure.DOT_PRODUCT; var dimensions = 1024; var maxInputTokens = 256; - var dimensionsSetByUser = true; var rateLimitSettings = new RateLimitSettings(5000); var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( @@ -125,7 +84,6 @@ public void testToXContent_WritesAllFields() throws IOException { similarity, dimensions, maxInputTokens, - dimensionsSetByUser, 
rateLimitSettings ); @@ -138,13 +96,12 @@ public void testToXContent_WritesAllFields() throws IOException { is( Strings.format( """ - {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d},"dimensions_set_by_user":%s}""", + {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", similarity, dimensions, maxInputTokens, modelId, - rateLimitSettings.requestsPerTimeUnit(), - dimensionsSetByUser + rateLimitSettings.requestsPerTimeUnit() ) ) ); @@ -159,7 +116,6 @@ public void testToXContent_WritesOnlyNonNullFields() throws IOException { null, // similarity null, // dimensions null, // maxInputTokens - false, // dimensionsSetByUser rateLimitSettings ); @@ -172,7 +128,7 @@ public void testToXContent_WritesOnlyNonNullFields() throws IOException { is( Strings.format( """ - {"model_id":"%s","rate_limit":{"requests_per_minute":%d},"dimensions_set_by_user":false}""", + {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit() ) @@ -189,7 +145,6 @@ public void testToXContentFragmentOfExposedFields() throws IOException { SimilarityMeasure.COSINE, 512, 128, - true, rateLimitSettings ); @@ -217,7 +172,6 @@ public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRa similarity, dimensions, maxInputTokens, - dimensionsSetByUser, rateLimitSettings ); } From 3489a099921884b5a01ab154215d5098bff5125b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 23 Jun 2025 18:17:22 +0000 Subject: [PATCH 13/23] [CI] Auto commit changes from spotless --- ...erviceDenseTextEmbeddingsServiceSettings.java | 16 ++-------------- ...eDenseTextEmbeddingsServiceSettingsTests.java | 15 +++------------ 2 files changed, 5 insertions(+), 26 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java index 10505c68cdfb5..c62f9e4a71e59 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -81,13 +81,7 @@ private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromReq throw validationException; } - return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - modelId, - similarity, - dims, - maxInputTokens, - rateLimitSettings - ); + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings); } private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPersistentMap( @@ -113,13 +107,7 @@ private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPer throw validationException; } - return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - modelId, - similarity, - dims, - maxInputTokens, - rateLimitSettings - ); + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings); } public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java index 4c661d8fed39b..7486367247b65 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -96,7 +96,7 @@ public void testToXContent_WritesAllFields() throws IOException { is( Strings.format( """ - {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", + {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", similarity, dimensions, maxInputTokens, @@ -123,17 +123,8 @@ public void testToXContent_WritesOnlyNonNullFields() throws IOException { serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - assertThat( - xContentResult, - is( - Strings.format( - """ - {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", - modelId, - rateLimitSettings.requestsPerTimeUnit() - ) - ) - ); + assertThat(xContentResult, is(Strings.format(""" + {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit()))); } public void testToXContentFragmentOfExposedFields() throws IOException { From fb5dbc09ddd3c43f47be32567bac23f443176ff6 Mon Sep 17 00:00:00 2001 From: Brendan Jugan Date: Mon, 23 Jun 2025 16:41:46 -0400 Subject: [PATCH 14/23] fix checkstyle --- ...icInferenceServiceDenseTextEmbeddingsServiceSettings.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java index c62f9e4a71e59..5047f34a1b2e3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -28,7 +28,10 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.ServiceFields.*; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; From 1dcbcab4a6367300a2aebef40826aed4df154f1e Mon Sep 17 00:00:00 2001 From: Brendan Jugan Date: Mon, 23 Jun 2025 17:45:05 -0400 Subject: [PATCH 15/23] fix checkstyle --- ...nseTextEmbeddingsServiceSettingsTests.java | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java index 7486367247b65..c74bcce3ca3e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -91,18 +91,20 @@ public void testToXContent_WritesAllFields() throws IOException { serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); + String expectedResult = Strings.format( + """ + {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", + similarity, + dimensions, + maxInputTokens, + modelId, + rateLimitSettings.requestsPerTimeUnit() + ); + assertThat( xContentResult, is( - Strings.format( - """ - {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", - similarity, - dimensions, - maxInputTokens, - modelId, - rateLimitSettings.requestsPerTimeUnit() - ) + expectedResult ) ); } From dc6f32088ebd2f07d75ea59cc56723f212e73ddc Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 23 Jun 2025 21:54:32 +0000 Subject: [PATCH 16/23] [CI] Auto commit changes from spotless --- ...ceServiceDenseTextEmbeddingsServiceSettingsTests.java | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java index c74bcce3ca3e5..fc22908f7a088 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -93,7 +93,7 @@ public void testToXContent_WritesAllFields() throws IOException { String expectedResult = Strings.format( """ - {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", + {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", similarity, dimensions, maxInputTokens, @@ -101,12 +101,7 @@ public void testToXContent_WritesAllFields() throws IOException { rateLimitSettings.requestsPerTimeUnit() ); - assertThat( - xContentResult, - is( - expectedResult - ) - ); + assertThat(xContentResult, is(expectedResult)); } public void testToXContent_WritesOnlyNonNullFields() throws IOException { From 087d4e592ea86f090a96895af2b286e1808a2ef0 Mon Sep 17 00:00:00 2001 From: Brendan Jugan Date: Mon, 23 Jun 2025 23:11:33 -0400 Subject: [PATCH 17/23] use ConstructingObjectParser for response parsing --- ...viceDenseTextEmbeddingsResponseEntity.java | 79 ++++++++++--------- ...ticInferenceServiceActionCreatorTests.java | 6 +- 2 files changed, 42 insertions(+), 43 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java index d80b9c74a6c6c..022ffe787ca6c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java @@ -7,10 +7,10 @@ package org.elasticsearch.xpack.inference.external.response.elastic; -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; @@ -18,19 +18,12 @@ import org.elasticsearch.xpack.inference.external.request.Request; import java.io.IOException; -import java.util.Collections; import java.util.List; -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { - private static final String FAILED_TO_FIND_FIELD_TEMPLATE = - "Failed to find required field [%s] in Elastic Inference Service dense 
text embeddings response"; - /** * Parses the Elastic Inference Service Dense Text Embeddings response. * @@ -64,43 +57,51 @@ public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { * * */ - public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { - var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { - moveToFirstToken(jsonParser); - - XContentParser.Token token = jsonParser.currentToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { + return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); + } + } - positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + public record EmbeddingFloatResult(List embeddingResults) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingFloatResult.class.getSimpleName(), + true, + args -> new EmbeddingFloatResult((List) args[0]) + ); - List parsedEmbeddings = parseList( - jsonParser, - (parser, index) -> ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.parseTextEmbeddingObject(parser) + static { + // Custom field declaration to handle array of arrays format + PARSER.declareField( + constructorArg(), + (parser, context) -> { + return XContentParserUtils.parseList(parser, (p, index) -> { + List embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue()); + return EmbeddingFloatResultEntry.fromFloatArray(embedding); + }); + }, + new ParseField("data"), + org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY ); - - if 
(parsedEmbeddings.isEmpty()) { - return new TextEmbeddingFloatResults(Collections.emptyList()); - } - - return new TextEmbeddingFloatResults(parsedEmbeddings); } - } - private static TextEmbeddingFloatResults.Embedding parseTextEmbeddingObject(XContentParser parser) throws IOException { - List embeddingValueList = parseList( - parser, - ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::parseEmbeddingFloatValueList - ); - return TextEmbeddingFloatResults.Embedding.of(embeddingValueList); + public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { + return new TextEmbeddingFloatResults( + embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() + ); + } } - private static float parseEmbeddingFloatValueList(XContentParser parser) throws IOException { - XContentParser.Token token = parser.currentToken(); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); - return parser.floatValue(); + /** + * Represents a single embedding entry in the response. + * For the Elastic Inference Service, each entry is just an array of floats (no wrapper object). + * This is a simpler wrapper that just holds the float array. 
+ */ + public record EmbeddingFloatResultEntry(List embedding) { + public static EmbeddingFloatResultEntry fromFloatArray(List floats) { + return new EmbeddingFloatResultEntry(floats); + } } private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index e63559ffd824d..0e490f11abe02 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -48,6 +48,7 @@ import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -407,10 +408,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - thrownException.getMessage(), - is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") - ); + assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]")); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); From cd3e116266b9af3bdf6f56c9dded0d5a33c946c5 Mon Sep 17 
00:00:00 2001 From: elasticsearchmachine Date: Tue, 24 Jun 2025 03:19:52 +0000 Subject: [PATCH 18/23] [CI] Auto commit changes from spotless --- ...erviceDenseTextEmbeddingsResponseEntity.java | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java index 022ffe787ca6c..a96ebc0048f70 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java @@ -73,17 +73,12 @@ public record EmbeddingFloatResult(List embeddingResu static { // Custom field declaration to handle array of arrays format - PARSER.declareField( - constructorArg(), - (parser, context) -> { - return XContentParserUtils.parseList(parser, (p, index) -> { - List embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue()); - return EmbeddingFloatResultEntry.fromFloatArray(embedding); - }); - }, - new ParseField("data"), - org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY - ); + PARSER.declareField(constructorArg(), (parser, context) -> { + return XContentParserUtils.parseList(parser, (p, index) -> { + List embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue()); + return EmbeddingFloatResultEntry.fromFloatArray(embedding); + }); + }, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY); } public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { From 
7269c519eee28b49d6b8adcd4aee9a8609e26e3b Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 24 Jun 2025 09:42:00 +0200 Subject: [PATCH 19/23] Some cleanup (removing unused vars etc.) --- .../inference/InferenceGetServicesIT.java | 7 +++-- .../elastic/ElasticInferenceServiceTests.java | 27 +++---------------- ...ticInferenceServiceActionCreatorTests.java | 8 +++--- ...eServiceDenseTextEmbeddingsModelTests.java | 7 +---- ...nseTextEmbeddingsServiceSettingsTests.java | 1 - ...erviceDenseTextEmbeddingsRequestTests.java | 2 +- 6 files changed, 13 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index dd52f72cfd456..b96c94db438a7 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -20,6 +20,7 @@ import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @@ -79,15 +80,13 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); assertThat(services.size(), equalTo(18)); - var providers = providers(services); - assertThat( providersFor(TaskType.TEXT_EMBEDDING), containsInAnyOrder( List.of( "alibabacloud-ai-search", "amazonbedrock", - "amazon_sagemaker" + "amazon_sagemaker", "azureaistudio", "azureopenai", "cohere", @@ -102,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws 
IOException { "openai", "text_embedding_test_service", "voyageai", - "watsonxai", + "watsonxai" ).toArray() ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index d6318b798238a..236a6be3d742d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -58,7 +58,6 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; @@ -89,7 +88,6 @@ import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static 
org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -748,7 +746,7 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); String productUseCase = "test-product-use-case"; threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); @@ -788,22 +786,8 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException } public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { - var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel( - getUrl(webServer), - "my-dense-model-id", - createRandomChunkingSettings() - ); - - testChunkedInfer_BatchesCalls(model); - } + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); - public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { - var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); - - testChunkedInfer_BatchesCalls(model); - } - - private void testChunkedInfer_BatchesCalls(ElasticInferenceServiceDenseTextEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createService(senderFactory, getUrl(webServer))) { @@ -871,12 +855,9 @@ private void testChunkedInfer_BatchesCalls(ElasticInferenceServiceDenseTextEmbed assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 
0.0f); } - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().getFirst().getUri().getQuery()); - MatcherAssert.assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) - ); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 0e490f11abe02..c8701b47a20b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -285,7 +285,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); var action = actionCreator.create(model); @@ -341,7 +341,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = 
ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); var action = actionCreator.create(model); @@ -396,7 +396,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); var action = actionCreator.create(model); @@ -438,7 +438,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); var action = actionCreator.create(model); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java index 5719bb94a0be7..6f920ba62736c 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; @@ -17,11 +16,7 @@ public class ElasticInferenceServiceDenseTextEmbeddingsModelTests { - public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel( - String url, - String modelId, - ChunkingSettings chunkingSettings - ) { + public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel(String url, String modelId) { return new ElasticInferenceServiceDenseTextEmbeddingsModel( "id", TaskType.TEXT_EMBEDDING, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java index fc22908f7a088..a9263d5624dca 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -152,7 +152,6 @@ public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRa var similarity = 
SimilarityMeasure.COSINE; var dimensions = randomBoolean() ? randomIntBetween(1, 1024) : null; var maxInputTokens = randomBoolean() ? randomIntBetween(128, 256) : null; - var dimensionsSetByUser = randomBoolean(); var rateLimitSettings = randomBoolean() ? new RateLimitSettings(randomIntBetween(1, 10000)) : null; return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java index 75a6c8a32b65a..86687980acdf6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java @@ -152,7 +152,7 @@ private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest( List inputs, InputType inputType ) { - var embeddingsModel = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId, null); + var embeddingsModel = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId); return new ElasticInferenceServiceDenseTextEmbeddingsRequest( embeddingsModel, From 220e20853099f3ac869bf0a33f6c9855a4fb3cc5 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 24 Jun 2025 10:19:02 +0200 Subject: [PATCH 20/23] Fix integration test --- .../InferenceRevokeDefaultEndpointsIT.java | 114 +++++++++--------- 1 file changed, 59 insertions(+), 55 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 1aa1f0532a51e..4c200c6f20247 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.mockito.Mockito.mock; public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { @@ -191,19 +192,19 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA String responseJson = """ { "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] }, + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + }, { "model_name": "multilingual-embed-v1", "task_types": ["embed/text/dense"] }, - { + { "model_name": "rerank-v1", "task_types": ["rerank/text/text-similarity"] } @@ -219,33 +220,31 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elser-v2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( 
- ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service + containsInAnyOrder( + new InferenceService.DefaultConfigId( + ".elser-v2-elastic", + MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), + service + ), + new InferenceService.DefaultConfigId( + ".rainbow-sprinkles-elastic", + MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), + service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-v1-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) + service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ) ); @@ -257,11 +256,11 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); - assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); assertThat( - listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), + listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic") ); + assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), 
is(".rainbow-sprinkles-elastic")); assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); var getModelListener = new PlainActionFuture(); @@ -284,6 +283,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA { "model_name": "rerank-v1", "task_types": ["rerank/text/text-similarity"] + }, + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] } ] } @@ -297,32 +300,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertThat( service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elser-v2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service + containsInAnyOrder( + new InferenceService.DefaultConfigId( + ".elser-v2-elastic", + MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), + service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-v1-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) + service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ) ); - 
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) + ); var getModelListener = new PlainActionFuture(); modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); From 27ca440cb242ebe9c56838ef7ad11385e910c6f3 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 24 Jun 2025 10:22:16 +0200 Subject: [PATCH 21/23] Do not set usage context, if it's null --- ...ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java index c149e6cf67063..6d7862f83cb69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java @@ -45,7 +45,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL_FIELD, modelId); // optional field - if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) { + if (Objects.nonNull(usageContext) && usageContext != ElasticInferenceServiceUsageContext.UNSPECIFIED) { builder.field(USAGE_CONTEXT, usageContext); } From b7d10b8ea3aec272afe0f09eb327532fb7edf6e1 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 24 Jun 2025 15:57:00 +0200 Subject: [PATCH 22/23] Pass through chunking settings and provide default for default endpoint --- 
.../services/elastic/ElasticInferenceService.java | 6 ++++-- ...icInferenceServiceDenseTextEmbeddingsModel.java | 14 ++++++++------ ...erenceServiceDenseTextEmbeddingsModelTests.java | 4 +++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 1a10bd2c1d189..834a0fd2d74ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -201,7 +201,8 @@ private static Map initDefaultEndpoints( ), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents + elasticInferenceServiceComponents, + ChunkingSettingsBuilder.DEFAULT_SETTINGS ), MinimalServiceSettings.textEmbedding( NAME, @@ -482,7 +483,8 @@ private static ElasticInferenceServiceModel createModel( taskSettings, secretSettings, elasticInferenceServiceComponents, - context + context, + chunkingSettings ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java index 8c2066ba7046e..dfbfaf47e2d2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; @@ -39,17 +40,18 @@ public ElasticInferenceServiceDenseTextEmbeddingsModel( Map taskSettings, Map secrets, ElasticInferenceServiceComponents elasticInferenceServiceComponents, - ConfigurationParseContext context + ConfigurationParseContext context, + ChunkingSettings chunkingSettings ) { this( inferenceEntityId, taskType, service, ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(serviceSettings, context), - // TODO: we probably want dense embeddings task settings EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents + elasticInferenceServiceComponents, + chunkingSettings ); } @@ -58,13 +60,13 @@ public ElasticInferenceServiceDenseTextEmbeddingsModel( TaskType taskType, String service, ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings, - // TODO: we probably want dense embeddings task settings @Nullable TaskSettings taskSettings, @Nullable SecretSettings secretSettings, - ElasticInferenceServiceComponents elasticInferenceServiceComponents + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ChunkingSettings chunkingSettings ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), serviceSettings, elasticInferenceServiceComponents diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java index 6f920ba62736c..fe0e4efc85a5b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -30,7 +31,8 @@ public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel(String ), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.of(url) + ElasticInferenceServiceComponents.of(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS ); } From fc11815469663bc8444584b4bf6c052e5e6f9c07 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 24 Jun 2025 17:30:35 +0200 Subject: [PATCH 23/23] After merge conflict resolution clean-up --- .../elastic/ElasticInferenceService.java | 48 ++++++++----------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 0d5c07416bb8a..640929b058760 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -353,21 +353,21 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - if(model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel denseTextEmbeddingsModel){ - var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); - - List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), - DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, - denseTextEmbeddingsModel.getConfigurations().getChunkingSettings() - ).batchRequestsWithListeners(listener); - - for (var request : batchedRequests) { - var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); - } - - return; + if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel denseTextEmbeddingsModel) { + var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, + denseTextEmbeddingsModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } + + return; } if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel 
sparseTextEmbeddingsModel) { @@ -423,8 +423,7 @@ public void parseRequestConfig( serviceSettingsMap, elasticInferenceServiceComponents, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), - ConfigurationParseContext.REQUEST, - chunkingSettings + ConfigurationParseContext.REQUEST ); throwIfNotEmptyMap(config, NAME); @@ -461,8 +460,7 @@ private static ElasticInferenceServiceModel createModel( @Nullable Map secretSettings, ElasticInferenceServiceComponents elasticInferenceServiceComponents, String failureMessage, - ConfigurationParseContext context, - ChunkingSettings chunkingSettings + ConfigurationParseContext context ) { return switch (taskType) { case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel( @@ -534,8 +532,7 @@ public Model parsePersistedConfigWithSecrets( taskSettingsMap, chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME), - chunkingSettings + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); } @@ -556,8 +553,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M taskSettingsMap, chunkingSettings, null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME), - chunkingSettings + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); } @@ -573,8 +569,7 @@ private ElasticInferenceServiceModel createModelFromPersistent( Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, - ChunkingSettings chunkingSettings + String failureMessage ) { return createModel( inferenceEntityId, @@ -585,8 +580,7 @@ private ElasticInferenceServiceModel createModelFromPersistent( secretSettings, elasticInferenceServiceComponents, failureMessage, - ConfigurationParseContext.PERSISTENT, - chunkingSettings + ConfigurationParseContext.PERSISTENT ); }