From 9d5ca1c5f2772edf680398ced196ce18da8b040a Mon Sep 17 00:00:00 2001 From: donalevans Date: Thu, 23 Oct 2025 10:25:26 -0700 Subject: [PATCH 1/6] Do not create inference endpoint if ID is used in existing mappings When creating an inference endpoint, if the inference ID is referenced by ingest pipeline processors or used in semantic_text mappings in non-empty indices, prevent the endpoint from being created. Closes #124272 --- .../inference/MinimalServiceSettings.java | 6 +- .../InferenceProcessorInfoExtractor.java | 5 +- .../xpack/inference/InferenceCrudIT.java | 114 ++++++++- .../mock/TestCompletionServiceExtension.java | 2 +- ...stStreamingCompletionServiceExtension.java | 2 +- .../CreateInferenceEndpointIT.java | 216 ++++++++++++++++++ ...ransportDeleteInferenceEndpointAction.java | 16 +- .../TransportPutInferenceModelAction.java | 93 +++++++- .../inference/LocalStateInferencePlugin.java | 6 +- 9 files changed, 438 insertions(+), 22 deletions(-) create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java index c05c800fc3424..463e95c16977e 100644 --- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java @@ -62,9 +62,9 @@ public record MinimalServiceSettings( public static final String SERVICE_FIELD = "service"; public static final String TASK_TYPE_FIELD = "task_type"; - static final String DIMENSIONS_FIELD = "dimensions"; - static final String SIMILARITY_FIELD = "similarity"; - static final String ELEMENT_TYPE_FIELD = "element_type"; + public static final String DIMENSIONS_FIELD = "dimensions"; + public static final String SIMILARITY_FIELD = "similarity"; + public static final String ELEMENT_TYPE_FIELD = "element_type"; private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "model_settings", diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java index 70c9ecf872e97..951d3f02d2cf5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/InferenceProcessorInfoExtractor.java @@ -116,13 +116,12 @@ public static Map> pipelineIdsByResource(ClusterState state, } /** - * @param state Current {@link ClusterState} + * @param metadata Current cluster state {@link Metadata} * @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them. 
*/ - public static Set pipelineIdsForResource(ClusterState state, Set ids) { + public static Set pipelineIdsForResource(Metadata metadata, Set ids) { assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); Set pipelineIds = new HashSet<>(); - Metadata metadata = state.metadata(); if (metadata == null) { return pipelineIds; } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 0a98787514010..3ad3115f4e754 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference; import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.Strings; @@ -211,7 +212,7 @@ public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException final String endpointId = "endpoint_referenced_by_semantic_text"; final String searchEndpointId = "search_endpoint_referenced_by_semantic_text"; final String indexName = randomAlphaOfLength(10).toLowerCase(); - final Function buildErrorString = endpointName -> " Inference endpoint " + final Function buildErrorString = endpointName -> "Inference endpoint " + endpointName + " is being used in the mapping for indexes: " + Set.of(indexName) @@ -303,6 +304,117 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws deleteIndex(indexName); } + public void testCreateEndpoint_withInferenceIdReferencedByPipeline() throws IOException { + String endpointId = "endpoint_referenced_by_pipeline"; + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + var pipelineId1 = "pipeline_referencing_model_1"; + var pipelineId2 = "pipeline_referencing_model_2"; + putPipeline(pipelineId1, endpointId); + putPipeline(pipelineId2, endpointId); + + deleteModel(endpointId, "force=true"); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING) + ); + assertThat( + responseException.getMessage(), + containsString("Inference endpoint [" + endpointId + "] could not be created because it is referenced by pipelines: [") + ); + assertThat(responseException.getMessage(), containsString(pipelineId1)); + assertThat(responseException.getMessage(), containsString(pipelineId2)); + + deletePipeline(pipelineId1); + deletePipeline(pipelineId2); + } + + public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException { + final String endpointId = "endpoint_referenced_by_semantic_text"; + final String otherEndpointId = "other_endpoint_referenced_by_semantic_text"; + final String indexName1 = randomAlphaOfLength(10).toLowerCase(); + final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase()); + + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + putModel(otherEndpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + // Create two indices, one 
where the inference ID of the endpoint we'll be deleting and recreating is used for + // inference_id and one where it's used for search_inference_id + putSemanticText(endpointId, otherEndpointId, indexName1); + putSemanticText(otherEndpointId, endpointId, indexName2); + + // Confirm that we can create the endpoint if there are no documents in the indices using it + deleteModel(endpointId, "force=true"); + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + + // Index a document into each index + var request1 = new Request("PUT", indexName1 + "/_create/1"); + request1.setJsonEntity("{\"inference_field\": \"value\"}"); + assertStatusOkOrCreated(client().performRequest(request1)); + + var request2 = new Request("PUT", indexName2 + "/_create/1"); + request2.setJsonEntity("{\"inference_field\": \"value\"}"); + assertStatusOkOrCreated(client().performRequest(request2)); + + assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh"))); + + deleteModel(endpointId, "force=true"); + + // Try to create an inference endpoint with the same ID + ResponseException responseException = assertThrows( + ResponseException.class, + () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING) + ); + assertThat( + responseException.getMessage(), + containsString( + "Inference endpoint [" + endpointId + "] could not be created because it is being used in mappings for indices: [" + ) + ); + assertThat(responseException.getMessage(), containsString(indexName1)); + assertThat(responseException.getMessage(), containsString(indexName2)); + + deleteIndex(indexName1); + deleteIndex(indexName2); + + deleteModel(otherEndpointId, "force=true"); + } + + public void testCreateEndpoint_withInferenceIdReferencedBySemanticTextAndPipeline() throws IOException { + String endpointId = "endpoint_referenced_by_semantic_text"; + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + String indexName = randomAlphaOfLength(10).toLowerCase(); + putSemanticText(endpointId, indexName); + + // Index a document into the index + var indexDocRequest = new Request("PUT", indexName + "/_create/1"); + indexDocRequest.setJsonEntity("{\"inference_field\": \"value\"}"); + assertStatusOkOrCreated(client().performRequest(indexDocRequest)); + + assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh"))); + + var pipelineId = "pipeline_referencing_model"; + putPipeline(pipelineId, endpointId); + + deleteModel(endpointId, "force=true"); + + String errorString = "Inference endpoint [" + + endpointId + + "] could not be created because it is being used in mappings for indices: [" + + indexName + + "] and referenced by pipelines: [" + + pipelineId + + "]."; + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING) + ); + assertThat(responseException.getMessage(), containsString(errorString)); + + deletePipeline(pipelineId); + deleteIndex(indexName); + } + public void testUnsupportedStream() throws Exception { String modelId = "streaming"; putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service")); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java index 728c39b634bd0..3d1b023984c6b 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestCompletionServiceExtension.java @@ -49,7 +49,7 @@ public List getInferenceServiceFactories() { } public static class TestInferenceService extends AbstractTestInferenceService { - private static final String NAME = "completion_test_service"; + public static final String NAME = "completion_test_service"; private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION); public TestInferenceService(InferenceServiceFactoryContext context) {} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 9ea2301abfa0c..422bbcb0beec8 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -59,7 +59,7 @@ public List getInferenceServiceFactories() { } public static class TestInferenceService extends AbstractTestInferenceService { - private static final String NAME = "streaming_completion_test_service"; + public static final String NAME = "streaming_completion_test_service"; private static final String ALIAS = "streaming_completion_test_service_alias"; private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java new file mode 100644 index 0000000000000..fe454c13428eb --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java @@ -0,0 +1,216 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.inference.MinimalServiceSettings.DIMENSIONS_FIELD; +import static org.elasticsearch.inference.MinimalServiceSettings.ELEMENT_TYPE_FIELD; +import static org.elasticsearch.inference.MinimalServiceSettings.SIMILARITY_FIELD; +import static org.elasticsearch.inference.SimilarityMeasure.COSINE; +import static org.elasticsearch.inference.SimilarityMeasure.L2_NORM; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; + +@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 +public class CreateInferenceEndpointIT extends ESIntegTestCase { + + public static final String INFERENCE_ID = "inference-id"; + public static final String SEMANTIC_TEXT_FIELD = "semantic-text-field"; + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins() { + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class); + } + + public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andDocumentsInIndex() throws IOException { + Map serviceSettings = getRandomTextEmbeddingServiceSettings(); + String otherInferenceId = "some-other-inference-id"; + assertEndpointCreationSuccessful(serviceSettings, INFERENCE_ID); + assertEndpointCreationSuccessful(serviceSettings, otherInferenceId); + + List indicesUsingInferenceId = new ArrayList<>(); + List indicesNotUsingInferenceId = new ArrayList<>(); + Set allIndices = new HashSet<>(); + int indexPairsToCreate = 10; + + for (int i = 0; i < indexPairsToCreate; ++i) { + String indexUsingInferenceId = 
createIndexWithSemanticTextMapping(INFERENCE_ID, allIndices); + indexDocument(indexUsingInferenceId); + indicesUsingInferenceId.add(indexUsingInferenceId); + allIndices.add(indexUsingInferenceId); + + String indexNotUsingInferenceId = createIndexWithSemanticTextMapping(otherInferenceId, allIndices); + indexDocument(indexNotUsingInferenceId); + indicesNotUsingInferenceId.add(indexNotUsingInferenceId); + allIndices.add(indexNotUsingInferenceId); + } + + forceDeleteInferenceEndpoint(); + + ElasticsearchStatusException statusException = expectThrows( + ElasticsearchStatusException.class, + () -> createTextEmbeddingEndpoint(serviceSettings, INFERENCE_ID).actionGet(TEST_REQUEST_TIMEOUT) + ); + + assertThat(statusException.status(), is(RestStatus.BAD_REQUEST)); + assertThat( + statusException.getMessage(), + containsString( + "Inference endpoint [" + INFERENCE_ID + "] could not be created because it is being used in mappings for indices: [" + ) + ); + + // Make sure we only report the indices that were using the inference ID + for (int i = 0; i < indexPairsToCreate; ++i) { + assertThat(statusException.getMessage(), containsString(indicesUsingInferenceId.get(i))); + assertThat(statusException.getMessage(), not(containsString(indicesNotUsingInferenceId.get(i)))); + } + } + + public void testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingThatInferenceIdExists_andNoDocumentsInIndex() + throws IOException { + Map serviceSettings = getRandomTextEmbeddingServiceSettings(); + assertEndpointCreationSuccessful(serviceSettings, INFERENCE_ID); + + createIndexWithSemanticTextMapping(); + + forceDeleteInferenceEndpoint(); + + assertEndpointCreationSuccessful(serviceSettings, INFERENCE_ID); + } + + public void testCreateInferenceEndpoint_succeeds_whenIndexIsCreatedBeforeInferenceEndpoint() throws IOException { + createIndexWithSemanticTextMapping(); + + assertEndpointCreationSuccessful(getRandomTextEmbeddingServiceSettings(), INFERENCE_ID); + } + + private void assertEndpointCreationSuccessful(Map serviceSettings, String inferenceId) throws IOException { + assertThat( + createTextEmbeddingEndpoint(serviceSettings, inferenceId).actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), + equalTo(inferenceId) + ); + } + + private ActionFuture createTextEmbeddingEndpoint( + Map serviceSettings, + String inferenceId + ) throws IOException { + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", TestDenseInferenceServiceExtension.TestInferenceService.NAME); + builder.field("service_settings", serviceSettings); + builder.endObject(); + content = BytesReference.bytes(builder); + } + + var request = new PutInferenceModelAction.Request( + TaskType.TEXT_EMBEDDING, + inferenceId, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + return client().execute(PutInferenceModelAction.INSTANCE, request); + } + + private void createIndexWithSemanticTextMapping() throws IOException { + createIndexWithSemanticTextMapping(CreateInferenceEndpointIT.INFERENCE_ID, Set.of()); + } + + private String createIndexWithSemanticTextMapping(String inferenceId, Set existingIndexNames) throws IOException { + // Ensure that all index names are unique + String indexName = randomValueOtherThanMany( + existingIndexNames::contains, + () -> ESTestCase.randomAlphaOfLength(10).toLowerCase(Locale.ROOT) + ); + XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties"); + 
mapping.startObject(SEMANTIC_TEXT_FIELD); + mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mapping.field("inference_id", inferenceId); + mapping.endObject().endObject().endObject(); + + assertAcked(prepareCreate(indexName).setMapping(mapping)); + return indexName; + } + + private static void indexDocument(String indexName) { + Map source = Map.of(SEMANTIC_TEXT_FIELD, randomAlphaOfLength(10)); + DocWriteResponse response = client().prepareIndex(indexName).setSource(source).get(TEST_REQUEST_TIMEOUT); + assertThat(response.getResult(), is(DocWriteResponse.Result.CREATED)); + client().admin().indices().prepareRefresh(indexName).get(); + } + + private static Map getRandomTextEmbeddingServiceSettings() { + Map settings = new HashMap<>(); + settings.put("model", "my_model"); + settings.put("api_key", "my_api_key"); + // Always use a dimension that's a multiple of 8 because the BIT element type requires that + settings.put(DIMENSIONS_FIELD, randomIntBetween(8, 128) * 8); + if (randomBoolean()) { + settings.put(ELEMENT_TYPE_FIELD, randomFrom(ElementType.values()).toString()); + } + if (randomBoolean()) { + // We can't use the DOT_PRODUCT similarity measure because it only works with unit-length vectors, which + // the TestDenseInferenceServiceExtension does not produce + settings.put(SIMILARITY_FIELD, randomFrom(COSINE, L2_NORM).toString()); + } + // The only supported similarity measure for BIT vectors is L2_NORM + if (ElementType.BIT.toString().equals(settings.get(ELEMENT_TYPE_FIELD))) { + settings.put(SIMILARITY_FIELD, L2_NORM.toString()); + } + return settings; + } + + private void forceDeleteInferenceEndpoint() { + var request = new DeleteInferenceEndpointAction.Request(INFERENCE_ID, TaskType.TEXT_EMBEDDING, true, false); + var responseFuture = client().execute(DeleteInferenceEndpointAction.INSTANCE, request); + responseFuture.actionGet(TEST_REQUEST_TIMEOUT); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index c100c9926b451..f94251947b6b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -31,13 +31,13 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; -import org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor; import org.elasticsearch.xpack.inference.common.InferenceExceptions; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.Set; import java.util.concurrent.Executor; +import static org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsForResource; import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; @@ -195,12 +195,9 @@ private static void handleDryRun( ClusterState state, ActionListener masterListener ) { - Set pipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())); + Set pipelines = 
endpointIsReferencedInPipelines(state, request.getInferenceEndpointId()); - Set indexesReferencedBySemanticText = extractIndexesReferencingInferenceEndpoints( - state.getMetadata(), - Set.of(request.getInferenceEndpointId()) - ); + Set indexesReferencedBySemanticText = endpointIsReferencedInIndex(state, request.getInferenceEndpointId()); masterListener.onResponse( new DeleteInferenceEndpointAction.Response( @@ -241,7 +238,10 @@ private static String buildErrorString(String inferenceEndpointId, Set p } if (indexes.isEmpty() == false) { - errorString.append(" Inference endpoint ") + if (errorString.isEmpty() == false) { + errorString.append(" "); + } + errorString.append("Inference endpoint ") .append(inferenceEndpointId) .append(" is being used in the mapping for indexes: ") .append(indexes) @@ -258,7 +258,7 @@ private static Set endpointIsReferencedInIndex(final ClusterState state, } private static Set endpointIsReferencedInPipelines(final ClusterState state, final String inferenceEndpointId) { - return InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(inferenceEndpointId)); + return pipelineIdsForResource(state.metadata(), Set.of(inferenceEndpointId)); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 80d57f888ef6e..97751df6aaea1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -11,11 +11,17 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.project.ProjectResolver; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; @@ -32,6 +38,7 @@ import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -49,11 +56,17 @@ import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.io.IOException; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; +import static org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsForResource; +import static 
org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; public class TransportPutInferenceModelAction extends TransportMasterNodeAction< @@ -65,6 +78,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction< private final XPackLicenseState licenseState; private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; + private final OriginSettingClient client; private volatile boolean skipValidationAndStart; private final ProjectResolver projectResolver; @@ -78,7 +92,8 @@ public TransportPutInferenceModelAction( ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, Settings settings, - ProjectResolver projectResolver + ProjectResolver projectResolver, + Client client ) { super( PutInferenceModelAction.NAME, @@ -97,6 +112,7 @@ public TransportPutInferenceModelAction( clusterService.getClusterSettings() .addSettingsUpdateConsumer(InferencePlugin.SKIP_VALIDATE_AND_START, this::setSkipValidationAndStart); this.projectResolver = projectResolver; + this.client = new OriginSettingClient(client, INFERENCE_ORIGIN); } @Override @@ -181,7 +197,15 @@ protected void masterOperation( return; } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener); + parseAndStoreModel( + service.get(), + request.getInferenceEntityId(), + resolvedTaskType, + requestAsMap, + request.getTimeout(), + state.metadata(), + listener + ); } private void parseAndStoreModel( @@ -190,6 +214,7 @@ private void parseAndStoreModel( TaskType taskType, Map config, TimeValue timeout, + Metadata metadata, ActionListener listener ) { ActionListener storeModelListener = listener.delegateFailureAndWrap( @@ -212,7 +237,7 @@ private void parseAndStoreModel( ) ); - ActionListener parsedModelListener = listener.delegateFailureAndWrap((delegate, model) -> { + ActionListener modelValidatingListener = listener.delegateFailureAndWrap((delegate, model) -> { if (skipValidationAndStart) { storeModelListener.onResponse(model); } else { @@ -221,7 +246,67 @@ private void parseAndStoreModel( } }); - service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener); + ActionListener existingUsesListener = listener.delegateFailureAndWrap((delegate, model) -> { + // Execute in another thread because checking for existing uses requires reading from indices + threadPool.executor(UTILITY_THREAD_POOL_NAME) + .execute(() -> checkForExistingUsesOfInferenceId(metadata, model, modelValidatingListener)); + }); + + service.parseRequestConfig(inferenceEntityId, taskType, config, existingUsesListener); + } + + private void checkForExistingUsesOfInferenceId(Metadata metadata, Model model, ActionListener modelValidatingListener) { + Set inferenceEntityIdSet = Set.of(model.getInferenceEntityId()); + Set nonEmptyIndices = findNonEmptyIndices(extractIndexesReferencingInferenceEndpoints(metadata, inferenceEntityIdSet)); + Set pipelinesUsingInferenceId = pipelineIdsForResource(metadata, inferenceEntityIdSet); + + if (nonEmptyIndices.isEmpty() && pipelinesUsingInferenceId.isEmpty()) { + modelValidatingListener.onResponse(model); + } else { + 
modelValidatingListener.onFailure( + new ElasticsearchStatusException( + buildErrorString(model.getInferenceEntityId(), nonEmptyIndices, pipelinesUsingInferenceId), + RestStatus.BAD_REQUEST + ) + ); + } + } + + private HashSet findNonEmptyIndices(Set indicesUsingInferenceId) { + var nonEmptyIndices = new HashSet(); + if (indicesUsingInferenceId.isEmpty() == false) { + // Search for documents in the indices + for (String indexName : indicesUsingInferenceId) { + SearchRequest countRequest = new SearchRequest(indexName); + countRequest.indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); + countRequest.allowPartialSearchResults(true); + // We just need to know whether any documents exist at all + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true).trackTotalHitsUpTo(1); + countRequest.source(searchSourceBuilder); + SearchResponse searchResponse = client.search(countRequest).actionGet(); + if (searchResponse.getHits().getTotalHits().value() > 0) { + nonEmptyIndices.add(indexName); + } + searchResponse.decRef(); + } + } + return nonEmptyIndices; + } + + private static String buildErrorString(String inferenceId, Set nonEmptyIndices, Set pipelinesUsingInferenceId) { + StringBuilder errorString = new StringBuilder(); + errorString.append("Inference endpoint [").append(inferenceId).append("] could not be created because it is "); + if (nonEmptyIndices.isEmpty() == false) { + errorString.append("being used in mappings for indices: ").append(nonEmptyIndices).append(" "); + } + if (pipelinesUsingInferenceId.isEmpty() == false) { + if (nonEmptyIndices.isEmpty() == false) { + errorString.append("and "); + } + errorString.append("referenced by pipelines: ").append(pipelinesUsingInferenceId); + } + errorString.append("."); + return errorString.toString(); } private void startInferenceEndpoint( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java index c2253c7f5424b..3a14cf6a851ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java @@ -16,9 +16,11 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin; import org.elasticsearch.xpack.core.ssl.SSLService; +import org.elasticsearch.xpack.inference.mock.TestCompletionServiceExtension; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestStreamingCompletionServiceExtension; import java.nio.file.Path; import java.util.Collection; @@ -49,7 +51,9 @@ public List getInferenceServiceFactories() { return List.of( TestSparseInferenceServiceExtension.TestInferenceService::new, TestDenseInferenceServiceExtension.TestInferenceService::new, - TestRerankingServiceExtension.TestInferenceService::new + TestRerankingServiceExtension.TestInferenceService::new, + TestCompletionServiceExtension.TestInferenceService::new, + TestStreamingCompletionServiceExtension.TestInferenceService::new ); } }; From 98a009191c17578bdab0b0b61c54fef1497748ba Mon Sep 17 00:00:00 2001 From: Donal Evans 
Date: Thu, 23 Oct 2025 10:59:45 -0700 Subject: [PATCH 2/6] Update docs/changelog/137055.yaml --- docs/changelog/137055.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/137055.yaml diff --git a/docs/changelog/137055.yaml b/docs/changelog/137055.yaml new file mode 100644 index 0000000000000..e2e0581a5f5ed --- /dev/null +++ b/docs/changelog/137055.yaml @@ -0,0 +1,6 @@ +pr: 137055 +summary: Do not create inference endpoint if ID is used in existing mappings +area: Machine Learning +type: bug +issues: + - 124272 From 5fb900e17ac0a23b120ff08f38c6f6fa853bf601 Mon Sep 17 00:00:00 2001 From: donalevans Date: Thu, 23 Oct 2025 11:18:55 -0700 Subject: [PATCH 3/6] Reword error message --- .../xpack/inference/InferenceCrudIT.java | 29 ++++++++++++++++--- .../CreateInferenceEndpointIT.java | 4 ++- .../TransportPutInferenceModelAction.java | 9 ++++-- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 3ad3115f4e754..4e74bcf3bdcf8 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -320,10 +320,21 @@ public void testCreateEndpoint_withInferenceIdReferencedByPipeline() throws IOEx ); assertThat( responseException.getMessage(), - containsString("Inference endpoint [" + endpointId + "] could not be created because it is referenced by pipelines: [") + containsString( + "Inference endpoint [" + + endpointId + + "] could not be created because the inference_id is already referenced by pipelines: [" + ) ); assertThat(responseException.getMessage(), containsString(pipelineId1)); assertThat(responseException.getMessage(), containsString(pipelineId2)); + assertThat( + responseException.getMessage(), + containsString( + "Please either use a different inference_id or update the index mappings " + + "and/or pipelines to refer to a different inference_id." + ) + ); deletePipeline(pipelineId1); deletePipeline(pipelineId2); @@ -367,11 +378,20 @@ public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws assertThat( responseException.getMessage(), containsString( - "Inference endpoint [" + endpointId + "] could not be created because it is being used in mappings for indices: [" + "Inference endpoint [" + + endpointId + + "] could not be created because the inference_id is already being used in mappings for indices: [" ) ); assertThat(responseException.getMessage(), containsString(indexName1)); assertThat(responseException.getMessage(), containsString(indexName2)); + assertThat( + responseException.getMessage(), + containsString( + "Please either use a different inference_id or update the index mappings " + + "and/or pipelines to refer to a different inference_id." 
+ ) + ); deleteIndex(indexName1); deleteIndex(indexName2); @@ -399,11 +419,12 @@ public void testCreateEndpoint_withInferenceIdReferencedBySemanticTextAndPipelin String errorString = "Inference endpoint [" + endpointId - + "] could not be created because it is being used in mappings for indices: [" + + "] could not be created because the inference_id is already being used in mappings for indices: [" + indexName + "] and referenced by pipelines: [" + pipelineId - + "]."; + + "]. Please either use a different inference_id or update the index mappings " + + "and/or pipelines to refer to a different inference_id."; ResponseException responseException = assertThrows( ResponseException.class, diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java index fe454c13428eb..7de4be4b9b342 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java @@ -101,7 +101,9 @@ public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInfer assertThat( statusException.getMessage(), containsString( - "Inference endpoint [" + INFERENCE_ID + "] could not be created because it is being used in mappings for indices: [" + "Inference endpoint [" + + INFERENCE_ID + + "] could not be created because the inference_id is already being used in mappings for indices: [" ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 97751df6aaea1..cc8357efe2ba3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -295,7 +295,9 @@ private HashSet findNonEmptyIndices(Set indicesUsingInferenceId) private static String buildErrorString(String inferenceId, Set nonEmptyIndices, Set pipelinesUsingInferenceId) { StringBuilder errorString = new StringBuilder(); - errorString.append("Inference endpoint [").append(inferenceId).append("] could not be created because it is "); + errorString.append("Inference endpoint [") + .append(inferenceId) + .append("] could not be created because the inference_id is already "); if (nonEmptyIndices.isEmpty() == false) { errorString.append("being used in mappings for indices: ").append(nonEmptyIndices).append(" "); } @@ -305,7 +307,10 @@ private static String buildErrorString(String inferenceId, Set nonEmptyI } errorString.append("referenced by pipelines: ").append(pipelinesUsingInferenceId); } - errorString.append("."); + errorString.append( + ". Please either use a different inference_id or update the index mappings " + + "and/or pipelines to refer to a different inference_id." 
+ ); return errorString.toString(); } From e95e06e9eb1f621e142b1c541d9c65f417df43a6 Mon Sep 17 00:00:00 2001 From: donalevans Date: Thu, 23 Oct 2025 13:17:10 -0700 Subject: [PATCH 4/6] Fix failing test --- .../TransportPutInferenceModelAction.java | 4 +- ..._text_query_inference_endpoint_changes.yml | 95 +------------------ 2 files changed, 7 insertions(+), 92 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index cc8357efe2ba3..430d65e8b9fa0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -299,11 +299,11 @@ private static String buildErrorString(String inferenceId, Set nonEmptyI .append(inferenceId) .append("] could not be created because the inference_id is already "); if (nonEmptyIndices.isEmpty() == false) { - errorString.append("being used in mappings for indices: ").append(nonEmptyIndices).append(" "); + errorString.append("being used in mappings for indices: ").append(nonEmptyIndices); } if (pipelinesUsingInferenceId.isEmpty() == false) { if (nonEmptyIndices.isEmpty() == false) { - errorString.append("and "); + errorString.append(" and "); } errorString.append("referenced by pipelines: ").append(pipelinesUsingInferenceId); } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml index 51595d40737a3..1e8c7bca78499 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml @@ -77,89 +77,14 @@ setup: non_inference_field: "non inference test" refresh: true --- -"sparse_embedding changed to text_embedding": - - do: - inference.delete: - inference_id: sparse-inference-id - force: true - - - do: - inference.put: - task_type: text_embedding - inference_id: sparse-inference-id - body: > - { - "service": "text_embedding_test_service", - "service_settings": { - "model": "my_model", - "dimensions": 10, - "api_key": "abc64", - "similarity": "COSINE" - }, - "task_settings": { - } - } - - - do: - catch: bad_request - search: - index: test-sparse-index - body: - query: - semantic: - field: "inference_field" - query: "inference test" - - - match: { error.caused_by.type: "illegal_argument_exception" } - - match: { error.caused_by.reason: "Field [inference_field] expected query inference results to be of type - [text_expansion_result], got [text_embedding_result]. Has the configuration for - inference endpoint [sparse-inference-id] changed?" 
} - ---- -"text_embedding changed to sparse_embedding": +"create endpoint fails when the inference_id is in use": - do: inference.delete: inference_id: dense-inference-id force: true - - do: - inference.put: - task_type: sparse_embedding - inference_id: dense-inference-id - body: > - { - "service": "test_service", - "service_settings": { - "model": "my_model", - "api_key": "abc64" - }, - "task_settings": { - } - } - - do: catch: bad_request - search: - index: test-dense-index - body: - query: - semantic: - field: "inference_field" - query: "inference test" - - - match: { error.caused_by.type: "illegal_argument_exception" } - - match: { error.caused_by.reason: "Field [inference_field] expected query inference results to be of type - [text_embedding_result], got [text_expansion_result]. Has the configuration for - inference endpoint [dense-inference-id] changed?" } - ---- -"text_embedding dimension count changed": - - do: - inference.delete: - inference_id: dense-inference-id - force: true - - - do: inference.put: task_type: text_embedding inference_id: dense-inference-id @@ -176,17 +101,7 @@ setup: } } - - do: - catch: bad_request - search: - index: test-dense-index - body: - query: - semantic: - field: "inference_field" - query: "inference test" - - - match: { error.caused_by.type: "illegal_argument_exception" } - - match: { error.caused_by.reason: "Field [inference_field] expected query inference results with 10 dimensions, got - 20 dimensions. Has the configuration for inference endpoint [dense-inference-id] - changed?" } + - match: { error.reason: "Inference endpoint [dense-inference-id] could not be created because the inference_id + is already being used in mappings for indices: [test-dense-index]. Please either use + a different inference_id or update the index mappings and/or pipelines to refer to a + different inference_id." 
} From a6fba1712d46bfefb573a7655450ee4ba6178474 Mon Sep 17 00:00:00 2001 From: donalevans Date: Thu, 30 Oct 2025 17:35:23 -0700 Subject: [PATCH 5/6] Check model settings for existing semantic text fields - Do not check for pipelines using the inference ID - Check if existing semantic text fields have compatible model settings - Update and expand test coverage for the new behaviour - Improve existing test InferenceServiceExtension implementations - Move SemanticTextInfoExtractor from xpack.core.ml.utils to xpack.inference.common --- .../ml/utils/SemanticTextInfoExtractor.java | 46 --- .../inference/InferenceBaseRestTest.java | 16 + .../xpack/inference/InferenceCrudIT.java | 122 ++------ .../TestDenseInferenceServiceExtension.java | 38 ++- .../mock/TestRerankingServiceExtension.java | 5 +- .../TestSparseInferenceServiceExtension.java | 7 +- ...stStreamingCompletionServiceExtension.java | 11 +- .../CreateInferenceEndpointIT.java | 279 +++++++++++++----- ...ransportDeleteInferenceEndpointAction.java | 2 +- .../TransportPutInferenceModelAction.java | 69 ++--- .../common/SemanticTextInfoExtractor.java | 75 +++++ .../inference/mapper/SemanticTextField.java | 6 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../settings/DefaultSecretSettings.java | 2 +- ..._text_query_inference_endpoint_changes.yml | 6 +- 15 files changed, 399 insertions(+), 287 deletions(-) delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java deleted file mode 100644 index d65e0117027a9..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- * - * this file was contributed to by a Generative AI - */ - -package org.elasticsearch.xpack.core.ml.utils; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.transport.Transports; - -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -public class SemanticTextInfoExtractor { - private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class); - - public static Set extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set endpointIds) { - assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); - assert endpointIds.isEmpty() == false; - assert metadata != null; - - Set referenceIndices = new HashSet<>(); - - Map indices = metadata.getProject().indices(); - - indices.forEach((indexName, indexMetadata) -> { - Map inferenceFields = indexMetadata.getInferenceFields(); - if (inferenceFields.values() - .stream() - .anyMatch(im -> endpointIds.contains(im.getInferenceId()) || endpointIds.contains(im.getSearchInferenceId()))) { - referenceIndices.add(indexName); - } - }); - - return referenceIndices; - } -} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 69256d49fe1d2..2c833186df0f0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() { """; } + static String mockDenseServiceModelConfig(int dimensions) { + return Strings.format(""" + { + "task_type": "text_embedding", + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_dense_vector_model", + "api_key": "abc64", + "dimensions": %s + }, + "task_settings": { + } + } + """, dimensions); + } + static String mockRerankServiceModelConfig() { return """ { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 4e74bcf3bdcf8..bf5c10233119d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -304,93 +304,66 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws deleteIndex(indexName); } - public void testCreateEndpoint_withInferenceIdReferencedByPipeline() throws IOException { - String endpointId = "endpoint_referenced_by_pipeline"; - putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); - var pipelineId1 = "pipeline_referencing_model_1"; - var pipelineId2 = "pipeline_referencing_model_2"; - putPipeline(pipelineId1, endpointId); - putPipeline(pipelineId2, 
endpointId);
-
-        deleteModel(endpointId, "force=true");
-
-        ResponseException responseException = assertThrows(
-            ResponseException.class,
-            () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING)
-        );
-        assertThat(
-            responseException.getMessage(),
-            containsString(
-                "Inference endpoint ["
-                    + endpointId
-                    + "] could not be created because the inference_id is already referenced by pipelines: ["
-            )
-        );
-        assertThat(responseException.getMessage(), containsString(pipelineId1));
-        assertThat(responseException.getMessage(), containsString(pipelineId2));
-        assertThat(
-            responseException.getMessage(),
-            containsString(
-                "Please either use a different inference_id or update the index mappings "
-                    + "and/or pipelines to refer to a different inference_id."
-            )
-        );
-
-        deletePipeline(pipelineId1);
-        deletePipeline(pipelineId2);
-    }
-
     public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws IOException {
         final String endpointId = "endpoint_referenced_by_semantic_text";
         final String otherEndpointId = "other_endpoint_referenced_by_semantic_text";
         final String indexName1 = randomAlphaOfLength(10).toLowerCase();
         final String indexName2 = randomValueOtherThan(indexName1, () -> randomAlphaOfLength(10).toLowerCase());
 
-        putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
-        putModel(otherEndpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
-        // Create two indices, one where the inference ID of the endpoint we'll be deleting and recreating is used for
-        // inference_id and one where it's used for search_inference_id
+        putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING);
+        putModel(otherEndpointId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
+        // Create two indices, one where the inference ID of the endpoint we'll be deleting and
+        // recreating is used for inference_id and one where it's used for search_inference_id
         putSemanticText(endpointId, otherEndpointId, indexName1);
         putSemanticText(otherEndpointId, endpointId, indexName2);
-        // Confirm that we can create the endpoint if there are no documents in the indices using it
+        // Confirm that we can create the endpoint with different settings if there
+        // are documents in the indices which do not use the semantic text field
+        var request = new Request("PUT", indexName1 + "/_create/1");
+        request.setJsonEntity("{\"non_inference_field\": \"value\"}");
+        assertStatusOkOrCreated(client().performRequest(request));
+
+        request = new Request("PUT", indexName2 + "/_create/1");
+        request.setJsonEntity("{\"non_inference_field\": \"value\"}");
+        assertStatusOkOrCreated(client().performRequest(request));
+
+        assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));
+
         deleteModel(endpointId, "force=true");
-        putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        putModel(endpointId, mockDenseServiceModelConfig(64), TaskType.TEXT_EMBEDDING);
 
-        // Index a document into each index
-        var request1 = new Request("PUT", indexName1 + "/_create/1");
-        request1.setJsonEntity("{\"inference_field\": \"value\"}");
-        assertStatusOkOrCreated(client().performRequest(request1));
+        // Index a document with the semantic text field into each index
+        request = new Request("PUT", indexName1 + "/_create/2");
+        request.setJsonEntity("{\"inference_field\": \"value\"}");
+        assertStatusOkOrCreated(client().performRequest(request));
 
-        var request2 = new Request("PUT", indexName2 + "/_create/1");
-        request2.setJsonEntity("{\"inference_field\": \"value\"}");
-        assertStatusOkOrCreated(client().performRequest(request2));
+        request = new Request("PUT", indexName2 + "/_create/2");
+        request.setJsonEntity("{\"inference_field\": \"value\"}");
+        assertStatusOkOrCreated(client().performRequest(request));
 
         assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));
 
         deleteModel(endpointId, "force=true");
 
-        // Try to create an inference endpoint with the same ID
+        // Try to create an inference endpoint with the same ID but different dimensions
+        // from when the document with the semantic text field was indexed
         ResponseException responseException = assertThrows(
             ResponseException.class,
-            () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING)
+            () -> putModel(endpointId, mockDenseServiceModelConfig(128), TaskType.TEXT_EMBEDDING)
         );
         assertThat(
             responseException.getMessage(),
             containsString(
                 "Inference endpoint ["
                     + endpointId
-                    + "] could not be created because the inference_id is already being used in mappings for indices: ["
+                    + "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: ["
             )
         );
         assertThat(responseException.getMessage(), containsString(indexName1));
         assertThat(responseException.getMessage(), containsString(indexName2));
         assertThat(
             responseException.getMessage(),
-            containsString(
-                "Please either use a different inference_id or update the index mappings "
-                    + "and/or pipelines to refer to a different inference_id."
-            )
+            containsString("Please either use a different inference_id or update the index mappings to refer to a different inference_id.")
         );
 
         deleteIndex(indexName1);
@@ -399,43 +372,6 @@ public void testCreateEndpoint_withInferenceIdReferencedBySemanticText() throws
         deleteModel(otherEndpointId, "force=true");
     }
 
-    public void testCreateEndpoint_withInferenceIdReferencedBySemanticTextAndPipeline() throws IOException {
-        String endpointId = "endpoint_referenced_by_semantic_text";
-        putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
-        String indexName = randomAlphaOfLength(10).toLowerCase();
-        putSemanticText(endpointId, indexName);
-
-        // Index a document into the index
-        var indexDocRequest = new Request("PUT", indexName + "/_create/1");
-        indexDocRequest.setJsonEntity("{\"inference_field\": \"value\"}");
-        assertStatusOkOrCreated(client().performRequest(indexDocRequest));
-
-        assertStatusOkOrCreated(client().performRequest(new Request("GET", "_refresh")));
-
-        var pipelineId = "pipeline_referencing_model";
-        putPipeline(pipelineId, endpointId);
-
-        deleteModel(endpointId, "force=true");
-
-        String errorString = "Inference endpoint ["
-            + endpointId
-            + "] could not be created because the inference_id is already being used in mappings for indices: ["
-            + indexName
-            + "] and referenced by pipelines: ["
-            + pipelineId
-            + "]. Please either use a different inference_id or update the index mappings "
-            + "and/or pipelines to refer to a different inference_id.";
-
-        ResponseException responseException = assertThrows(
-            ResponseException.class,
-            () -> putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING)
-        );
-        assertThat(responseException.getMessage(), containsString(errorString));
-
-        deletePipeline(pipelineId);
-        deleteIndex(indexName);
-    }
-
     public void testUnsupportedStream() throws Exception {
         String modelId = "streaming";
         putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
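Read end to end, the updated test drives the same scenario a user would hit over REST: populate a semantic_text field while the endpoint has one set of settings, force-delete the endpoint, then try to recreate it with incompatible settings. A hedged sketch of that interaction, with hypothetical endpoint and model names (written inside an ESRestTestCase like the one above, so client() and expectThrows are inherited):

    // Hedged sketch; "my-endpoint" and the settings values are hypothetical.
    // Precondition: a semantic_text field mapped to inference_id "my-endpoint" was
    // populated while the endpoint had dimensions=64; the endpoint was then force-deleted.
    Request put = new Request("PUT", "_inference/text_embedding/my-endpoint");
    put.setJsonEntity("""
        {
          "service": "text_embedding_test_service",
          "service_settings": { "model_id": "my-model", "api_key": "my-key", "dimensions": 128 }
        }
        """);
    ResponseException e = expectThrows(ResponseException.class, () -> client().performRequest(put));
    // Expected: 400 with "... could not be created because the inference_id is being
    // used in mappings with incompatible settings for indices: [...]"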
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
index 45ecb3dedf3f1..dd63facfe7dda 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
@@ -170,7 +170,12 @@ public void chunkedInfer(
     private DenseEmbeddingFloatResults makeResults(List<String> input, ServiceSettings serviceSettings) {
         List<DenseEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
         for (String inputString : input) {
-            List<Float> floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
+            List<Float> floatEmbeddings = generateEmbedding(
+                inputString,
+                serviceSettings.dimensions(),
+                serviceSettings.elementType(),
+                serviceSettings.similarity()
+            );
             embeddings.add(DenseEmbeddingFloatResults.Embedding.of(floatEmbeddings));
         }
         return new DenseEmbeddingFloatResults(embeddings);
@@ -206,7 +211,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
      * <ul>
      * <li>Unique to the input</li>
      * <li>Reproducible (i.e given the same input, the same embedding should be generated)</li>
-     * <li>Valid for the provided element type</li>
+     * <li>Valid for the provided element type and similarity measure</li>
      * </ul>
      * <p>
      * The embedding is generated by:
@@ -216,6 +221,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
      * <li>converting the hash code value to a string</li>
      * <li>converting the string to a UTF-8 encoded byte array</li>
      * <li>repeatedly appending the byte array to the embedding until the desired number of dimensions are populated</li>
+     * <li>converting the embedding to a unit vector if the similarity measure requires that</li>
      * </ul>
      * <p>
      * Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8
@@ -226,11 +232,17 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
      * embedding byte.
      *
-     * @param input The input string
-     * @param dimensions The embedding dimension count
+     * @param input             The input string
+     * @param dimensions        The embedding dimension count
+     * @param similarityMeasure The similarity measure
      * @return An embedding
      */
-    private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
+    private static List<Float> generateEmbedding(
+        String input,
+        int dimensions,
+        DenseVectorFieldMapper.ElementType elementType,
+        SimilarityMeasure similarityMeasure
+    ) {
         int embeddingLength = getEmbeddingLength(elementType, dimensions);
         List<Float> embedding = new ArrayList<>(embeddingLength);
@@ -248,6 +260,9 @@ private static List<Float> generateEmbedding(String input, int dimensions, Dense
         if (remainingLength > 0) {
             embedding.addAll(embeddingValues.subList(0, remainingLength));
         }
+        if (similarityMeasure == SimilarityMeasure.DOT_PRODUCT) {
+            embedding = toUnitVector(embedding);
+        }
 
         return embedding;
     }
@@ -263,6 +278,11 @@ private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType element
         };
     }
 
+    private static List<Float> toUnitVector(List<Float> embedding) {
+        var magnitude = (float) Math.sqrt(embedding.stream().reduce(0f, (a, b) -> a + (b * b)));
+        return embedding.stream().map(v -> v / magnitude).toList();
+    }
+
     public static class Configuration {
         public static InferenceServiceConfiguration get() {
             return configuration.getOrCompute();
@@ -304,9 +324,13 @@ public record TestServiceSettings(
         public static TestServiceSettings fromMap(Map<String, Object> map) {
             ValidationException validationException = new ValidationException();
 
-            String model = (String) map.remove("model");
+            String model = (String) map.remove("model_id");
+
             if (model == null) {
-                validationException.addValidationError("missing model");
+                model = (String) map.remove("model");
+                if (model == null) {
+                    validationException.addValidationError("missing model");
+                }
             }
 
             Integer dimensions = (Integer) map.remove("dimensions");
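The DOT_PRODUCT branch above depends on unit-length vectors. As a standalone sketch (the class name here is illustrative, not part of the patch), the normalization divides each component by the Euclidean norm, after which the vector's dot product with itself is 1:

    import java.util.List;

    final class UnitVectorDemo {
        // Divide each component by the Euclidean norm, mirroring the toUnitVector()
        // helper added above; the result has magnitude 1, as DOT_PRODUCT requires.
        static List<Float> toUnitVector(List<Float> v) {
            float magnitude = (float) Math.sqrt(v.stream().reduce(0f, (a, b) -> a + b * b));
            return v.stream().map(x -> x / magnitude).toList();
        }

        public static void main(String[] args) {
            List<Float> unit = toUnitVector(List.of(3f, 4f)); // [0.6, 0.8]
            double selfDot = unit.stream().mapToDouble(x -> (double) x * x).sum();
            System.out.println(unit + " self dot product = " + selfDot); // ~1.0
        }
    }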
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
index 962fc9e1ee818..919c2338b3b88 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
@@ -318,7 +318,10 @@ public static TestServiceSettings fromMap(Map<String, Object> map) {
             String model = (String) map.remove("model_id");
 
             if (model == null) {
-                validationException.addValidationError("missing model");
+                model = (String) map.remove("model");
+                if (model == null) {
+                    validationException.addValidationError("missing model");
+                }
             }
 
             if (validationException.validationErrors().isEmpty() == false) {
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
index 1a7a95536c755..2c6606e720e56 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
@@ -268,10 +268,13 @@ public record TestServiceSettings(String model, String hiddenField, boolean shou
         public static TestServiceSettings fromMap(Map<String, Object> map) {
             ValidationException validationException = new ValidationException();
 
-            String model = (String) map.remove("model");
+            String model = (String) map.remove("model_id");
 
             if (model == null) {
-                validationException.addValidationError("missing model");
+                model = (String) map.remove("model");
+                if (model == null) {
+                    validationException.addValidationError("missing model");
+                }
             }
 
             String hiddenField = (String) map.remove("hidden_field");
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
index 422bbcb0beec8..f07f5eb01d47b 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
@@ -343,12 +343,15 @@ public TestServiceSettings(StreamInput in) throws IOException {
         }
 
         public static TestServiceSettings fromMap(Map<String, Object> map) {
-            var modelId = map.remove("model").toString();
+            String modelId = (String) map.remove("model_id");
 
             if (modelId == null) {
-                ValidationException validationException = new ValidationException();
-                validationException.addValidationError("missing model id");
-                throw validationException;
+                modelId = (String) map.remove("model");
+                if (modelId == null) {
+                    ValidationException validationException = new ValidationException();
+                    validationException.addValidationError("missing model id");
+                    throw validationException;
+                }
             }
 
             return new TestServiceSettings(modelId);
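All four mock services above converge on the same key-lookup order: prefer "model_id", fall back to the legacy "model" key, and only report an error when both are absent. A minimal self-contained sketch of that pattern (class and method names are illustrative):

    import java.util.HashMap;
    import java.util.Map;

    final class ModelIdFallbackDemo {
        // Same parsing order the mock services now share: "model_id" first,
        // then the legacy "model" key, and an error only if neither is present.
        static String extractModelId(Map<String, Object> map) {
            String model = (String) map.remove("model_id");
            if (model == null) {
                model = (String) map.remove("model");
            }
            if (model == null) {
                throw new IllegalArgumentException("missing model");
            }
            return model;
        }

        public static void main(String[] args) {
            System.out.println(extractModelId(new HashMap<>(Map.of("model_id", "my-model"))));
            System.out.println(extractModelId(new HashMap<>(Map.of("model", "legacy-model"))));
        }
    }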
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java
index 7de4be4b9b342..27e8d089de03a 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CreateInferenceEndpointIT.java
@@ -13,6 +13,7 @@
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.license.LicenseSettings;
 import org.elasticsearch.plugins.Plugin;
@@ -25,14 +26,19 @@
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
 import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
+import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
 import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
 import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
+import org.elasticsearch.xpack.inference.mock.TestRerankingServiceExtension;
+import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
+import org.elasticsearch.xpack.inference.mock.TestStreamingCompletionServiceExtension;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collection;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -40,12 +46,18 @@
 import java.util.Map;
 import java.util.Set;
 
-import static org.elasticsearch.inference.MinimalServiceSettings.DIMENSIONS_FIELD;
-import static org.elasticsearch.inference.MinimalServiceSettings.ELEMENT_TYPE_FIELD;
-import static org.elasticsearch.inference.MinimalServiceSettings.SIMILARITY_FIELD;
+import static org.elasticsearch.inference.ModelConfigurations.SERVICE;
+import static org.elasticsearch.inference.ModelConfigurations.SERVICE_SETTINGS;
 import static org.elasticsearch.inference.SimilarityMeasure.COSINE;
 import static org.elasticsearch.inference.SimilarityMeasure.L2_NORM;
+import static org.elasticsearch.inference.TaskType.ANY;
+import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.ELEMENT_TYPE;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings.API_KEY;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.containsString;
@@ -55,7 +67,8 @@
 public class CreateInferenceEndpointIT extends ESIntegTestCase {
 
     public static final String INFERENCE_ID = "inference-id";
-    public static final String SEMANTIC_TEXT_FIELD = "semantic-text-field";
+    public static final String NOT_MODIFIED_INFERENCE_ID = "not-modified-inference-id";
+    public static final String SEMANTIC_TEXT_FIELD_NAME = "semantic-text-field";
 
     @Override
     protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
@@ -67,34 +80,65 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
         return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class);
     }
 
-    public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andDocumentsInIndex() throws IOException {
-        Map<String, Object> serviceSettings = getRandomTextEmbeddingServiceSettings();
-        String otherInferenceId = "some-other-inference-id";
-        assertEndpointCreationSuccessful(serviceSettings, INFERENCE_ID);
-        assertEndpointCreationSuccessful(serviceSettings, otherInferenceId);
+    public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andTaskTypeIsIncompatible()
+        throws IOException {
+        modifyEndpointAndAssertFailure(true, null);
+    }
 
-        List<String> indicesUsingInferenceId = new ArrayList<>();
-        List<String> indicesNotUsingInferenceId = new ArrayList<>();
-        Set<String> allIndices = new HashSet<>();
-        int indexPairsToCreate = 10;
+    public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andDimensionsAreIncompatible()
+        throws IOException {
+        modifyEndpointAndAssertFailure(false, DIMENSIONS);
+    }
 
-        for (int i = 0; i < indexPairsToCreate; ++i) {
-            String indexUsingInferenceId = createIndexWithSemanticTextMapping(INFERENCE_ID, allIndices);
-            indexDocument(indexUsingInferenceId);
-            indicesUsingInferenceId.add(indexUsingInferenceId);
-            allIndices.add(indexUsingInferenceId);
+    public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andElementTypeIsIncompatible()
+        throws IOException {
+        modifyEndpointAndAssertFailure(false, ELEMENT_TYPE);
+    }
 
-            String indexNotUsingInferenceId = createIndexWithSemanticTextMapping(otherInferenceId, allIndices);
-            indexDocument(indexNotUsingInferenceId);
-            indicesNotUsingInferenceId.add(indexNotUsingInferenceId);
-            allIndices.add(indexNotUsingInferenceId);
-        }
+    public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInferenceIdExists_andSimilarityIsIncompatible()
+        throws IOException {
+        modifyEndpointAndAssertFailure(false, SIMILARITY);
+    }
+
+    public void testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingTheInferenceIdExists_andAllSettingsAreTheSame()
+        throws IOException {
+        modifyEndpointAndAssertSuccess(null, true);
+    }
+
+    public void testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingTheInferenceIdExists_andModelIdIsDifferent()
+        throws IOException {
+        modifyEndpointAndAssertSuccess(MODEL_ID, true);
+    }
+
+    public void testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingTheInferenceIdExists_andApiKeyIsDifferent()
+        throws IOException {
+        modifyEndpointAndAssertSuccess(API_KEY, true);
+    }
+
+    public void testCreateInferenceEndpoint_succeeds_whenNoDocumentsUsingSemanticTextHaveBeenIndexed() throws IOException {
+        String fieldToModify = randomFrom(DIMENSIONS, ELEMENT_TYPE, SIMILARITY);
+        modifyEndpointAndAssertSuccess(fieldToModify, false);
+    }
+
+    public void testCreateInferenceEndpoint_succeeds_whenIndexIsCreatedBeforeInferenceEndpoint() throws IOException {
+        String inferenceId = NOT_MODIFIED_INFERENCE_ID;
+        String indexName = createIndexWithSemanticTextMapping(inferenceId);
+
+        assertEndpointCreationSuccessful(randomTaskType(), getRandomServiceSettings(), inferenceId);
+
+        IntegrationTestUtils.deleteIndex(client(), indexName);
+    }
 
-        forceDeleteInferenceEndpoint();
+    private void modifyEndpointAndAssertFailure(boolean modifyTaskType, String settingsFieldToModify) throws IOException {
+        TaskType taskType = TEXT_EMBEDDING;
+        Map<String, Object> serviceSettings = getRandomServiceSettings();
+        Set<String> indicesUsingInferenceId = new HashSet<>();
+
+        String indexNotUsingInferenceId = indexDocumentsAndDeleteEndpoint(taskType, serviceSettings, indicesUsingInferenceId, true);
 
         ElasticsearchStatusException statusException = expectThrows(
             ElasticsearchStatusException.class,
-            () -> createTextEmbeddingEndpoint(serviceSettings, INFERENCE_ID).actionGet(TEST_REQUEST_TIMEOUT)
+            () -> createEndpointWithModifiedSettings(modifyTaskType, settingsFieldToModify, taskType, serviceSettings)
        );
 
         assertThat(statusException.status(), is(RestStatus.BAD_REQUEST));
@@ -103,67 +147,117 @@ public void testCreateInferenceEndpoint_fails_whenSemanticTextFieldUsingTheInfer
             containsString(
                 "Inference endpoint ["
                     + INFERENCE_ID
-                    + "] could not be created because the inference_id is already being used in mappings for indices: ["
+                    + "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: ["
             )
         );
 
         // Make sure we only report the indices that were using the inference ID
-        for (int i = 0; i < indexPairsToCreate; ++i) {
-            assertThat(statusException.getMessage(), containsString(indicesUsingInferenceId.get(i)));
-            assertThat(statusException.getMessage(), not(containsString(indicesNotUsingInferenceId.get(i))));
-        }
+        indicesUsingInferenceId.forEach(index -> assertThat(statusException.getMessage(), containsString(index)));
+        assertThat(statusException.getMessage(), not(containsString(indexNotUsingInferenceId)));
+
+        indicesUsingInferenceId.forEach(index -> IntegrationTestUtils.deleteIndex(client(), index));
+        IntegrationTestUtils.deleteIndex(client(), indexNotUsingInferenceId);
     }
 
-    public void testCreateInferenceEndpoint_succeeds_whenSemanticTextFieldUsingThatInferenceIdExists_andNoDocumentsInIndex()
-        throws IOException {
-        Map<String, Object> serviceSettings = getRandomTextEmbeddingServiceSettings();
-        assertEndpointCreationSuccessful(serviceSettings, INFERENCE_ID);
+    private void modifyEndpointAndAssertSuccess(String fieldToModify, boolean documentHasSemanticText) throws IOException {
+        TaskType taskType = TEXT_EMBEDDING;
+        Map<String, Object> serviceSettings = getRandomServiceSettings();
+        HashSet<String> indicesUsingInferenceId = new HashSet<>();
 
-        createIndexWithSemanticTextMapping();
+        String indexNotUsingInferenceId = indexDocumentsAndDeleteEndpoint(
+            taskType,
+            serviceSettings,
+            indicesUsingInferenceId,
+            documentHasSemanticText
+        );
 
-        forceDeleteInferenceEndpoint();
+        PutInferenceModelAction.Response response = createEndpointWithModifiedSettings(false, fieldToModify, taskType, serviceSettings);
+        assertThat(response.getModel().getInferenceEntityId(), equalTo(INFERENCE_ID));
 
-        assertEndpointCreationSuccessful(serviceSettings, INFERENCE_ID);
+        indicesUsingInferenceId.forEach(index -> IntegrationTestUtils.deleteIndex(client(), index));
+        IntegrationTestUtils.deleteIndex(client(), indexNotUsingInferenceId);
    }
 
-    public void testCreateInferenceEndpoint_succeeds_whenIndexIsCreatedBeforeInferenceEndpoint() throws IOException {
-        createIndexWithSemanticTextMapping();
+    private String indexDocumentsAndDeleteEndpoint(
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Set<String> indicesUsingInferenceId,
+        boolean documentHasSemanticText
+    ) throws IOException {
+        assertEndpointCreationSuccessful(taskType, serviceSettings, INFERENCE_ID);
+        assertEndpointCreationSuccessful(taskType, serviceSettings, NOT_MODIFIED_INFERENCE_ID);
+
+        // Create several indices to confirm that we can identify them all in the error message
+        for (int i = 0; i < 5; ++i) {
+            String indexUsingInferenceId = createIndexWithSemanticTextMapping(INFERENCE_ID, indicesUsingInferenceId);
+            indexDocument(indexUsingInferenceId, documentHasSemanticText);
+            indicesUsingInferenceId.add(indexUsingInferenceId);
+        }
+
+        // Also create a second endpoint which will not be deleted and recreated, and an index which is using it
+        String indexNotUsingInferenceId = createIndexWithSemanticTextMapping(NOT_MODIFIED_INFERENCE_ID, indicesUsingInferenceId);
+        indexDocument(indexNotUsingInferenceId, documentHasSemanticText);
+
+        forceDeleteInferenceEndpoint(INFERENCE_ID, taskType);
+        return indexNotUsingInferenceId;
+    }
 
-        assertEndpointCreationSuccessful(getRandomTextEmbeddingServiceSettings(), INFERENCE_ID);
+    private PutInferenceModelAction.Response createEndpointWithModifiedSettings(
+        boolean modifyTaskType,
+        String fieldToModify,
+        TaskType taskType,
+        Map<String, Object> serviceSettings
+    ) {
+        TaskType newTaskType = modifyTaskType ? randomValueOtherThan(taskType, CreateInferenceEndpointIT::randomTaskType) : taskType;
+        Map<String, Object> newSettings = fieldToModify != null ? modifyServiceSettings(serviceSettings, fieldToModify) : serviceSettings;
+        return createEndpoint(newTaskType, newSettings, INFERENCE_ID).actionGet(TEST_REQUEST_TIMEOUT);
    }
 
-    private void assertEndpointCreationSuccessful(Map<String, Object> serviceSettings, String inferenceId) throws IOException {
+    private void assertEndpointCreationSuccessful(TaskType taskType, Map<String, Object> serviceSettings, String inferenceId) {
         assertThat(
-            createTextEmbeddingEndpoint(serviceSettings, inferenceId).actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(),
+            createEndpoint(taskType, serviceSettings, inferenceId).actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(),
             equalTo(inferenceId)
         );
     }
 
-    private ActionFuture<PutInferenceModelAction.Response> createTextEmbeddingEndpoint(
+    private ActionFuture<PutInferenceModelAction.Response> createEndpoint(
+        TaskType taskType,
         Map<String, Object> serviceSettings,
         String inferenceId
-    ) throws IOException {
+    ) {
         final BytesReference content;
         try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
             builder.startObject();
-            builder.field("service", TestDenseInferenceServiceExtension.TestInferenceService.NAME);
-            builder.field("service_settings", serviceSettings);
+            builder.field(SERVICE, getServiceForTaskType(taskType));
+            builder.field(SERVICE_SETTINGS, serviceSettings);
             builder.endObject();
 
             content = BytesReference.bytes(builder);
+        } catch (IOException ex) {
+            throw new AssertionError(ex);
         }
 
-        var request = new PutInferenceModelAction.Request(
-            TaskType.TEXT_EMBEDDING,
-            inferenceId,
-            content,
-            XContentType.JSON,
-            TEST_REQUEST_TIMEOUT
-        );
+        var request = new PutInferenceModelAction.Request(taskType, inferenceId, content, XContentType.JSON, TEST_REQUEST_TIMEOUT);
         return client().execute(PutInferenceModelAction.INSTANCE, request);
     }
 
-    private void createIndexWithSemanticTextMapping() throws IOException {
-        createIndexWithSemanticTextMapping(CreateInferenceEndpointIT.INFERENCE_ID, Set.of());
+    private String getServiceForTaskType(TaskType taskType) {
+        return switch (taskType) {
+            case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
+            case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
+            case RERANK -> TestRerankingServiceExtension.TestInferenceService.NAME;
+            case COMPLETION, CHAT_COMPLETION -> TestStreamingCompletionServiceExtension.TestInferenceService.NAME;
+            default -> throw new IllegalStateException("Unexpected value: " + taskType);
+        };
+    }
+
+    private static TaskType randomTaskType() {
+        EnumSet<TaskType> taskTypes = EnumSet.allOf(TaskType.class);
+        taskTypes.remove(ANY);
+        return randomFrom(taskTypes);
+    }
+
+    private String createIndexWithSemanticTextMapping(String inferenceId) throws IOException {
+        return createIndexWithSemanticTextMapping(inferenceId, Set.of());
     }
 
     private String createIndexWithSemanticTextMapping(String inferenceId, Set<String> existingIndexNames) throws IOException {
@@ -172,46 +266,73 @@ private String createIndexWithSemanticTextMapping(String inferenceId, Set<String>
             () -> ESTestCase.randomAlphaOfLength(10).toLowerCase(Locale.ROOT)
         );
 
-        XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties");
-        mapping.startObject(SEMANTIC_TEXT_FIELD);
-        mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
-        mapping.field("inference_id", inferenceId);
+        XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject(ElasticsearchMappings.PROPERTIES);
+        mapping.startObject(SEMANTIC_TEXT_FIELD_NAME);
+        mapping.field(ElasticsearchMappings.TYPE, SemanticTextFieldMapper.CONTENT_TYPE);
+        mapping.field(SemanticTextField.INFERENCE_ID_FIELD, inferenceId);
         mapping.endObject().endObject().endObject();
 
         assertAcked(prepareCreate(indexName).setMapping(mapping));
 
         return indexName;
     }
 
-    private static void indexDocument(String indexName) {
-        Map<String, Object> source = Map.of(SEMANTIC_TEXT_FIELD, randomAlphaOfLength(10));
+    private static void indexDocument(String indexName, boolean withSemanticText) {
+        var source = new HashMap<String, Object>();
+        source.put("field", "value");
+        if (withSemanticText) {
+            source.put(SEMANTIC_TEXT_FIELD_NAME, randomAlphaOfLength(10));
+        }
 
         DocWriteResponse response = client().prepareIndex(indexName).setSource(source).get(TEST_REQUEST_TIMEOUT);
         assertThat(response.getResult(), is(DocWriteResponse.Result.CREATED));
 
         client().admin().indices().prepareRefresh(indexName).get();
     }
 
-    private static Map<String, Object> getRandomTextEmbeddingServiceSettings() {
+    private static Map<String, Object> getRandomServiceSettings() {
         Map<String, Object> settings = new HashMap<>();
-        settings.put("model", "my_model");
-        settings.put("api_key", "my_api_key");
+        settings.put(MODEL_ID, randomIdentifier());
+        settings.put(API_KEY, randomIdentifier());
 
         // Always use a dimension that's a multiple of 8 because the BIT element type requires that
-        settings.put(DIMENSIONS_FIELD, randomIntBetween(8, 128) * 8);
-        if (randomBoolean()) {
-            settings.put(ELEMENT_TYPE_FIELD, randomFrom(ElementType.values()).toString());
-        }
-        if (randomBoolean()) {
-            // We can't use the DOT_PRODUCT similarity measure because it only works with unit-length vectors, which
-            // the TestDenseInferenceServiceExtension does not produce
-            settings.put(SIMILARITY_FIELD, randomFrom(COSINE, L2_NORM).toString());
-        }
-        // The only supported similarity measure for BIT vectors is L2_NORM
-        if (ElementType.BIT.toString().equals(settings.get(ELEMENT_TYPE_FIELD))) {
-            settings.put(SIMILARITY_FIELD, L2_NORM.toString());
+        settings.put(DIMENSIONS, randomIntBetween(1, 32) * 8);
+        ElementType elementType = randomFrom(ElementType.values());
+        settings.put(ELEMENT_TYPE, elementType.toString());
+        if (elementType == ElementType.BIT) {
+            // The only supported similarity measure for BIT vectors is L2_NORM
+            settings.put(SIMILARITY, L2_NORM.toString());
+        } else if (elementType == ElementType.BYTE) {
+            // DOT_PRODUCT similarity does not work with BYTE due to how TestDenseInferenceServiceExtension creates embeddings
+            settings.put(SIMILARITY, randomFrom(L2_NORM, COSINE).toString());
+        } else {
+            settings.put(SIMILARITY, randomFrom(SimilarityMeasure.values()).toString());
         }
 
         return settings;
     }
 
-    private void forceDeleteInferenceEndpoint() {
-        var request = new DeleteInferenceEndpointAction.Request(INFERENCE_ID, TaskType.TEXT_EMBEDDING, true, false);
+    private static Map<String, Object> modifyServiceSettings(Map<String, Object> serviceSettings, String fieldToModify) {
+        var newServiceSettings = new HashMap<>(serviceSettings);
+        switch (fieldToModify) {
+            case MODEL_ID, API_KEY -> newServiceSettings.compute(
+                fieldToModify,
+                (k, value) -> randomValueOtherThan(value, ESTestCase::randomIdentifier)
+            );
+            case DIMENSIONS -> newServiceSettings.compute(
+                DIMENSIONS,
+                (k, dimensions) -> randomValueOtherThan(dimensions, () -> randomIntBetween(8, 128) * 8)
+            );
+            case ELEMENT_TYPE -> newServiceSettings.compute(
+                ELEMENT_TYPE,
+                (k, elementType) -> randomValueOtherThan(elementType, () -> randomFrom(ElementType.values()).toString())
+            );
+            case SIMILARITY -> newServiceSettings.compute(
+                SIMILARITY,
+                (k, similarity) -> randomValueOtherThan(similarity, () -> randomFrom(SimilarityMeasure.values()).toString())
+            );
+            default -> throw new AssertionError("Invalid service settings field " + fieldToModify);
+        }
+        return newServiceSettings;
+    }
+
+    private void forceDeleteInferenceEndpoint(String inferenceId, TaskType taskType) {
+        var request = new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false);
         var responseFuture = client().execute(DeleteInferenceEndpointAction.INSTANCE, request);
         responseFuture.actionGet(TEST_REQUEST_TIMEOUT);
     }
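The pass/fail split in these tests reflects which settings are mapping-relevant: task type, dimensions, element type, and similarity must be compatible with what was recorded in the index mapping, while model_id and api_key may change freely. A hedged sketch of that comparison (this record is illustrative; the real check goes through MinimalServiceSettings.canMergeWith via canMergeModelSettings):

    import java.util.Objects;

    // Illustrative stand-in for the compatibility rule the tests above probe: two
    // endpoint definitions may share an inference_id only if the mapping-relevant
    // settings match; model_id and api_key are deliberately not compared.
    record MappingRelevantSettings(String taskType, int dimensions, String elementType, String similarity) {
        boolean compatibleWith(MappingRelevantSettings other) {
            return Objects.equals(taskType, other.taskType)
                && dimensions == other.dimensions
                && Objects.equals(elementType, other.elementType)
                && Objects.equals(similarity, other.similarity);
        }
    }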
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java
index f94251947b6b7..512ad5a445b18 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java
@@ -38,8 +38,8 @@
 import java.util.concurrent.Executor;
 
 import static org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsForResource;
-import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints;
 import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
+import static org.elasticsearch.xpack.inference.common.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints;
 
 public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction<
     DeleteInferenceEndpointAction.Request,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java
index 430d65e8b9fa0..f7a563d9bfed9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java
@@ -11,10 +11,7 @@
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.search.SearchRequest;
-import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.ActionFilters;
-import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.master.TransportMasterNodeAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.client.internal.OriginSettingClient;
@@ -28,9 +25,11 @@
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.index.mapper.StrictDynamicMappingException;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
@@ -38,7 +37,6 @@
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
@@ -63,10 +61,10 @@
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
-import static org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsForResource;
-import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints;
 import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
 import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
+import static org.elasticsearch.xpack.inference.common.SemanticTextInfoExtractor.getModelSettingsForIndicesReferencingInferenceEndpoints;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
 
 public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
@@ -257,61 +255,40 @@ private void parseAndStoreModel(
 
     private void checkForExistingUsesOfInferenceId(Metadata metadata, Model model, ActionListener<Model> modelValidatingListener) {
         Set<String> inferenceEntityIdSet = Set.of(model.getInferenceEntityId());
-        Set<String> nonEmptyIndices = findNonEmptyIndices(extractIndexesReferencingInferenceEndpoints(metadata, inferenceEntityIdSet));
-        Set<String> pipelinesUsingInferenceId = pipelineIdsForResource(metadata, inferenceEntityIdSet);
+        Set<String> indicesWithIncompatibleMappings = findIndicesWithIncompatibleMappings(model, metadata, inferenceEntityIdSet);
 
-        if (nonEmptyIndices.isEmpty() && pipelinesUsingInferenceId.isEmpty()) {
+        if (indicesWithIncompatibleMappings.isEmpty()) {
             modelValidatingListener.onResponse(model);
         } else {
             modelValidatingListener.onFailure(
                 new ElasticsearchStatusException(
-                    buildErrorString(model.getInferenceEntityId(), nonEmptyIndices, pipelinesUsingInferenceId),
+                    buildErrorString(model.getInferenceEntityId(), indicesWithIncompatibleMappings),
                     RestStatus.BAD_REQUEST
                 )
             );
        }
    }
 
-    private HashSet<String> findNonEmptyIndices(Set<String> indicesUsingInferenceId) {
-        var nonEmptyIndices = new HashSet<String>();
-        if (indicesUsingInferenceId.isEmpty() == false) {
-            // Search for documents in the indices
-            for (String indexName : indicesUsingInferenceId) {
-                SearchRequest countRequest = new SearchRequest(indexName);
-                countRequest.indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
-                countRequest.allowPartialSearchResults(true);
-                // We just need to know whether any documents exist at all
-                SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true).trackTotalHitsUpTo(1);
-                countRequest.source(searchSourceBuilder);
-                SearchResponse searchResponse = client.search(countRequest).actionGet();
-                if (searchResponse.getHits().getTotalHits().value() > 0) {
-                    nonEmptyIndices.add(indexName);
+    private Set<String> findIndicesWithIncompatibleMappings(Model model, Metadata metadata, Set<String> inferenceEntityIdSet) {
+        var serviceSettingsMap = getModelSettingsForIndicesReferencingInferenceEndpoints(metadata, inferenceEntityIdSet);
+        var incompatibleIndices = new HashSet<String>();
+        if (serviceSettingsMap.isEmpty() == false) {
+            MinimalServiceSettings newSettings = new MinimalServiceSettings(model);
+            serviceSettingsMap.forEach((indexName, existingSettings) -> {
+                if (canMergeModelSettings(existingSettings, newSettings, new FieldMapper.Conflicts("")) == false) {
+                    incompatibleIndices.add(indexName);
                 }
-                searchResponse.decRef();
-            }
+            });
         }
-        return nonEmptyIndices;
+        return incompatibleIndices;
     }
 
-    private static String buildErrorString(String inferenceId, Set<String> nonEmptyIndices, Set<String> pipelinesUsingInferenceId) {
-        StringBuilder errorString = new StringBuilder();
-        errorString.append("Inference endpoint [")
-            .append(inferenceId)
-            .append("] could not be created because the inference_id is already ");
-        if (nonEmptyIndices.isEmpty() == false) {
-            errorString.append("being used in mappings for indices: ").append(nonEmptyIndices);
-        }
-        if (pipelinesUsingInferenceId.isEmpty() == false) {
-            if (nonEmptyIndices.isEmpty() == false) {
-                errorString.append(" and ");
-            }
-            errorString.append("referenced by pipelines: ").append(pipelinesUsingInferenceId);
-        }
-        errorString.append(
-            ". Please either use a different inference_id or update the index mappings "
-                + "and/or pipelines to refer to a different inference_id."
-        );
-        return errorString.toString();
+    private static String buildErrorString(String inferenceId, Set<String> indicesWithIncompatibleMappings) {
+        return "Inference endpoint ["
+            + inferenceId
+            + "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: "
+            + indicesWithIncompatibleMappings
+            + ". Please either use a different inference_id or update the index mappings to refer to a different inference_id.";
     }
 
     private void startInferenceEndpoint(
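The filtering step in findIndicesWithIncompatibleMappings above reduces to collecting the indices whose stored settings fail a mergeability test. A generic, self-contained sketch of that shape (the predicate stands in for canMergeModelSettings; all names here are illustrative):

    import java.util.HashSet;
    import java.util.Map;
    import java.util.Set;
    import java.util.function.BiPredicate;

    final class IncompatibleIndexFinderDemo {
        // Given per-index settings extracted from mappings and a mergeability
        // predicate, keep only the indices whose settings cannot merge with the new ones.
        static <S> Set<String> findIncompatible(Map<String, S> settingsByIndex, S newSettings, BiPredicate<S, S> canMerge) {
            Set<String> incompatible = new HashSet<>();
            settingsByIndex.forEach((index, existing) -> {
                if (canMerge.test(existing, newSettings) == false) {
                    incompatible.add(index);
                }
            });
            return incompatible;
        }

        public static void main(String[] args) {
            Map<String, String> byIndex = Map.of("index-a", "dims=128", "index-b", "dims=64");
            System.out.println(findIncompatible(byIndex, "dims=128", String::equals)); // [index-b]
        }
    }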
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java
new file mode 100644
index 0000000000000..e4ca1c7dbfa93
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SemanticTextInfoExtractor.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.cluster.metadata.MappingMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.inference.MinimalServiceSettings;
+import org.elasticsearch.transport.Transports;
+import org.elasticsearch.xcontent.ObjectPath;
+import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+public class SemanticTextInfoExtractor {
+    public static Set<String> extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set<String> endpointIds) {
+        assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
+        assert endpointIds.isEmpty() == false;
+        assert metadata != null;
+
+        Set<String> referenceIndices = new HashSet<>();
+
+        Map<String, IndexMetadata> indices = metadata.getProject().indices();
+
+        indices.forEach((indexName, indexMetadata) -> {
+            Map<String, InferenceFieldMetadata> inferenceFields = indexMetadata.getInferenceFields();
+            if (inferenceFields.values()
+                .stream()
+                .anyMatch(im -> endpointIds.contains(im.getInferenceId()) || endpointIds.contains(im.getSearchInferenceId()))) {
+                referenceIndices.add(indexName);
+            }
+        });
+
+        return referenceIndices;
+    }
+
+    public static Map<String, MinimalServiceSettings> getModelSettingsForIndicesReferencingInferenceEndpoints(
+        Metadata metadata,
+        Set<String> endpointIds
+    ) {
+        assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
+        assert endpointIds.isEmpty() == false;
+        assert metadata != null;
+
+        Map<String, MinimalServiceSettings> serviceSettingsMap = new HashMap<>();
+
+        metadata.getProject().indices().forEach((indexName, indexMetadata) -> {
+            indexMetadata.getInferenceFields()
+                .values()
+                .stream()
+                .filter(field -> endpointIds.contains(field.getInferenceId()) || endpointIds.contains(field.getSearchInferenceId()))
+                .findFirst() // Assume that the model settings are the same for all fields using the inference endpoint
+                .ifPresent(field -> {
+                    MappingMetadata mapping = indexMetadata.mapping();
+                    if (mapping != null) {
+                        String[] pathArray = { ElasticsearchMappings.PROPERTIES, field.getName(), SemanticTextField.MODEL_SETTINGS_FIELD };
+                        Object modelSettings = ObjectPath.eval(pathArray, mapping.sourceAsMap());
+                        serviceSettingsMap.put(indexName, SemanticTextField.parseModelSettingsFromMap(modelSettings));
+                    }
+                });
+        });
+
+        return serviceSettingsMap;
+    }
+}
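getModelSettingsForIndicesReferencingInferenceEndpoints above reads the recorded model settings straight out of the index mapping source via ObjectPath.eval. A standalone illustration of that path walk over a parsed mapping (the eval helper here is a simplified stand-in, not the real ObjectPath):

    import java.util.Map;

    final class MappingPathDemo {
        // Simplified stand-in for ObjectPath.eval(): walk a parsed mapping source
        // map along properties.<field>.model_settings, returning null on a miss.
        static Object eval(String[] path, Object node) {
            for (String key : path) {
                if (node instanceof Map<?, ?> map) {
                    node = map.get(key);
                } else {
                    return null;
                }
            }
            return node;
        }

        public static void main(String[] args) {
            Map<String, Object> mappingSource = Map.of(
                "properties", Map.of(
                    "semantic-text-field", Map.of(
                        "type", "semantic_text",
                        "inference_id", "inference-id",
                        "model_settings", Map.of("task_type", "text_embedding", "dimensions", 256)
                    )
                )
            );
            String[] path = { "properties", "semantic-text-field", "model_settings" };
            // Prints the nested settings map, e.g. {task_type=text_embedding, dimensions=256}
            // (Map.of iteration order is unspecified).
            System.out.println(eval(path, mappingSource));
        }
    }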
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
index ba455a6ca8af0..eaece2974ba64 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
@@ -62,7 +62,7 @@ public record SemanticTextField(
     static final String TEXT_FIELD = "text";
     static final String INFERENCE_FIELD = "inference";
-    static final String INFERENCE_ID_FIELD = "inference_id";
+    public static final String INFERENCE_ID_FIELD = "inference_id";
     static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
     static final String CHUNKS_FIELD = "chunks";
     static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
@@ -70,7 +70,7 @@ public record SemanticTextField(
     static final String CHUNKED_OFFSET_FIELD = "offset";
     static final String CHUNKED_START_OFFSET_FIELD = "start_offset";
     static final String CHUNKED_END_OFFSET_FIELD = "end_offset";
-    static final String MODEL_SETTINGS_FIELD = "model_settings";
+    public static final String MODEL_SETTINGS_FIELD = "model_settings";
     static final String CHUNKING_SETTINGS_FIELD = "chunking_settings";
 
     public record InferenceResult(
@@ -108,7 +108,7 @@ static SemanticTextField parse(XContentParser parser, ParserContext context) thr
         return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context);
     }
 
-    static MinimalServiceSettings parseModelSettingsFromMap(Object node) {
+    public static MinimalServiceSettings parseModelSettingsFromMap(Object node) {
         if (node == null) {
             return null;
         }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 6caefd09b6c59..aef57c6755f9f 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -1470,7 +1470,7 @@ static SemanticTextIndexOptions defaultIndexOptions(IndexVersion indexVersionCre
         return null;
     }
 
-    private static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) {
+    public static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) {
         if (previous != null && current != null && previous.canMergeWith(current)) {
             return true;
         }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java
index d076c946889ed..745d6f585a137 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/DefaultSecretSettings.java
@@ -36,7 +36,7 @@ public record DefaultSecretSettings(SecureString apiKey) implements SecretSettings, ApiKeySecrets {
     public static final String NAME = "default_secret_settings";
 
-    static final String API_KEY = "api_key";
+    public static final String API_KEY = "api_key";
 
     public static DefaultSecretSettings fromMap(@Nullable Map<String, Object> map) {
         if (map == null) {
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml
index 1e8c7bca78499..01c91012beff7 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml
@@ -77,7 +77,7 @@ setup:
         non_inference_field: "non inference test"
   refresh: true
 
 ---
-"create endpoint fails when the inference_id is in use":
+"create endpoint fails when the inference_id is used by a semantic text field and is incompatible":
   - do:
       inference.delete:
         inference_id: dense-inference-id
@@ -102,6 +102,6 @@ setup:
           }
   - match: { error.reason: "Inference endpoint [dense-inference-id] could not be created because the inference_id
-                            is already being used in mappings for indices: [test-dense-index]. Please either use
-                            a different inference_id or update the index mappings and/or pipelines to refer to a
+                            is being used in mappings with incompatible settings for indices: [test-dense-index].
+                            Please either use a different inference_id or update the index mappings to refer to a
                             different inference_id." }

From d566301b589504f4a2e50761d965820391d750ff Mon Sep 17 00:00:00 2001
From: donalevans
Date: Mon, 3 Nov 2025 08:49:19 -0800
Subject: [PATCH 6/6] Restore default visibility to MinimalServiceSettings
 constants

---
 .../org/elasticsearch/inference/MinimalServiceSettings.java | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java
index 463e95c16977e..c05c800fc3424 100644
--- a/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java
+++ b/server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java
@@ -62,9 +62,9 @@ public record MinimalServiceSettings(
     public static final String SERVICE_FIELD = "service";
     public static final String TASK_TYPE_FIELD = "task_type";
 
-    public static final String DIMENSIONS_FIELD = "dimensions";
-    public static final String SIMILARITY_FIELD = "similarity";
-    public static final String ELEMENT_TYPE_FIELD = "element_type";
+    static final String DIMENSIONS_FIELD = "dimensions";
+    static final String SIMILARITY_FIELD = "similarity";
+    static final String ELEMENT_TYPE_FIELD = "element_type";
 
     private static final ConstructingObjectParser<MinimalServiceSettings, Void> PARSER = new ConstructingObjectParser<>(
        "model_settings",