From 762b6d471c7ed6ec2417569a5d0441eb82282459 Mon Sep 17 00:00:00 2001 From: Donal Evans Date: Wed, 15 Oct 2025 15:15:54 -0700 Subject: [PATCH] [ML] Clean up inference indices on failed endpoint creation (#136577) Prior to this change, if only one of the .inference or .secrets-inference indices was updated when creating an inference endpoint, the endpoint creation would fail, but the successfully written doc was not removed, leading to inconsistent document counts between the two indices. This commit removes any documents that were written in the case that a partial failure occurred, but does not change the behaviour in the case where no updates to the indices were made. - Invoke a cleanup listener if a partial failure occurred when storing inference endpoint information in the .inference and .secrets-inference indices - Refactor ModelRegistry to use BulkRequestBuilder.add(IndexRequestBuilder) instead of the deprecated BulkRequestBuilder.add(IndexRequest) - Include cause when logging bulk failure during inference endpoint creation - Add integration tests for the new behaviour - Update docs/changelog/136577.yaml Closes #123726 (cherry picked from commit 9abc0bd7b5c990af4a24b795dcac0a3ed2104498) --- docs/changelog/136577.yaml | 6 + .../inference/registry/ModelRegistry.java | 128 ++++++++++++------ .../registry/ModelRegistryTests.java | 102 +++++++++++++- 3 files changed, 191 insertions(+), 45 deletions(-) create mode 100644 docs/changelog/136577.yaml diff --git a/docs/changelog/136577.yaml b/docs/changelog/136577.yaml new file mode 100644 index 0000000000000..e48d831a41ff5 --- /dev/null +++ b/docs/changelog/136577.yaml @@ -0,0 +1,6 @@ +pr: 136577 +summary: Clean up inference indices on failed endpoint creation +area: Machine Learning +type: bug +issues: + - 123726 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 7cd1cf5999d11..a042839b23d4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -21,6 +21,7 @@ import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.GroupedActionListener; @@ -531,11 +532,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi SubscribableListener.newForked((subListener) -> { // in this block, we try to update the stored model configurations - IndexRequest configRequest = createIndexRequest( - Model.documentId(inferenceEntityId), + var configRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, newModel.getConfigurations(), - true + true, + client ); ActionListener storeConfigListener = subListener.delegateResponse((l, e) -> { @@ -544,7 +546,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); - client.prepareBulk().add(configRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener); + client.prepareBulk() + .add(configRequestBuilder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .execute(storeConfigListener); }).andThen((subListener, configResponse) -> { // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets @@ -569,11 +574,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi ); } else { // Since the model configurations were successfully updated, we can now try to store the new secrets - IndexRequest secretsRequest = createIndexRequest( - Model.documentId(newModel.getConfigurations().getInferenceEntityId()), + var secretsRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceSecretsIndex.INDEX_NAME, newModel.getSecrets(), - true + true, + client ); ActionListener storeSecretsListener = subListener.delegateResponse((l, e) -> { @@ -583,7 +589,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi }); client.prepareBulk() - .add(secretsRequest) + .add(secretsRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(storeSecretsListener); } @@ -591,12 +597,14 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi // in this block, we respond to the success or failure of updating the model secrets if (secretsResponse.hasFailures()) { // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations - IndexRequest configRequest = createIndexRequest( - Model.documentId(inferenceEntityId), + var configRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, existingModel.getConfigurations(), - true + true, + client ); + logger.error( "Failed to update inference endpoint secrets [{}], attempting rolling back to previous state", inferenceEntityId @@ -608,7 +616,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); client.prepareBulk() - .add(configRequest) + .add(configRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(rollbackConfigListener); } else { @@ -655,24 +663,25 @@ public void storeModel(Model model, ActionListener listener, TimeValue private void storeModel(Model model, boolean updateClusterState, ActionListener listener, TimeValue timeout) { ActionListener bulkResponseActionListener = getStoreIndexListener(model, updateClusterState, listener, timeout); - - IndexRequest configRequest = createIndexRequest( - Model.documentId(model.getConfigurations().getInferenceEntityId()), + String inferenceEntityId = model.getConfigurations().getInferenceEntityId(); + var configRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, model.getConfigurations(), - false + false, + client ); - - IndexRequest secretsRequest = createIndexRequest( - Model.documentId(model.getConfigurations().getInferenceEntityId()), + var secretsRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceSecretsIndex.INDEX_NAME, model.getSecrets(), - false + false, + client ); client.prepareBulk() - .add(configRequest) - .add(secretsRequest) + .add(configRequestBuilder) + .add(secretsRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(bulkResponseActionListener); } @@ -683,15 +692,24 @@ private ActionListener getStoreIndexListener( ActionListener listener, TimeValue timeout ) { + // If there was a partial failure in writing to the indices, we need to clean up + AtomicBoolean partialFailure = new AtomicBoolean(false); + var cleanupListener = listener.delegateResponse((delegate, ex) -> { + if (partialFailure.get()) { + deleteModel(model.getInferenceEntityId(), ActionListener.running(() -> delegate.onFailure(ex))); + } else { + delegate.onFailure(ex); + } + }); return ActionListener.wrap(bulkItemResponses -> { - var inferenceEntityId = model.getConfigurations().getInferenceEntityId(); + var inferenceEntityId = model.getInferenceEntityId(); if (bulkItemResponses.getItems().length == 0) { logger.warn( format("Storing inference endpoint [%s] failed, no items were received from the bulk response", inferenceEntityId) ); - listener.onFailure( + cleanupListener.onFailure( new ElasticsearchStatusException( format( "Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service", @@ -707,7 +725,7 @@ private ActionListener getStoreIndexListener( if (failure == null) { if (updateClusterState) { - var storeListener = getStoreMetadataListener(inferenceEntityId, listener); + var storeListener = getStoreMetadataListener(inferenceEntityId, cleanupListener); try { metadataTaskQueue.submitTask( "add model [" + inferenceEntityId + "]", @@ -723,19 +741,22 @@ private ActionListener getStoreIndexListener( storeListener.onFailure(exc); } } else { - listener.onResponse(Boolean.TRUE); + cleanupListener.onResponse(Boolean.TRUE); } return; } - logBulkFailures(model.getConfigurations().getInferenceEntityId(), bulkItemResponses); + for (BulkItemResponse aResponse : bulkItemResponses.getItems()) { + logBulkFailure(inferenceEntityId, aResponse); + partialFailure.compareAndSet(false, aResponse.isFailed() == false); + } if (ExceptionsHelper.unwrapCause(failure.getCause()) instanceof VersionConflictEngineException) { - listener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId)); + cleanupListener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId)); return; } - listener.onFailure( + cleanupListener.onFailure( new ElasticsearchStatusException( format("Failed to store inference endpoint [%s]", inferenceEntityId), RestStatus.INTERNAL_SERVER_ERROR, @@ -743,9 +764,9 @@ private ActionListener getStoreIndexListener( ) ); }, e -> { - String errorMessage = format("Failed to store inference endpoint [%s]", model.getConfigurations().getInferenceEntityId()); + String errorMessage = format("Failed to store inference endpoint [%s]", model.getInferenceEntityId()); logger.warn(errorMessage, e); - listener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e)); + cleanupListener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e)); }); } @@ -779,18 +800,12 @@ public void onFailure(Exception exc) { }; } - private static void logBulkFailures(String inferenceEntityId, BulkResponse bulkResponse) { - for (BulkItemResponse item : bulkResponse.getItems()) { - if (item.isFailed()) { - logger.warn( - format( - "Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]", - inferenceEntityId, - item.getIndex(), - item.getFailureMessage() - ) - ); - } + private static void logBulkFailure(String inferenceEntityId, BulkItemResponse item) { + if (item.isFailed()) { + logger.warn( + format("Failed to store inference endpoint [%s] index: [%s]", inferenceEntityId, item.getIndex()), + item.getFailure().getCause() + ); } } @@ -937,6 +952,33 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T } } + static IndexRequestBuilder createIndexRequestBuilder( + String inferenceId, + String indexName, + ToXContentObject body, + boolean allowOverwriting, + Client client + ) { + try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) { + XContentBuilder source = body.toXContent( + xContentBuilder, + new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString())) + ); + + return new IndexRequestBuilder(client).setIndex(indexName) + .setCreate(allowOverwriting == false) + .setId(Model.documentId(inferenceId)) + .setSource(source); + } catch (IOException ex) { + throw new ElasticsearchException( + "Unexpected serialization exception for index [{}] inference ID [{}]", + ex, + indexName, + inferenceId + ); + } + } + private static UnparsedModel modelToUnparsedModel(Model model) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { model.getConfigurations() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index eee8550ec6524..2980621e52c6f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -10,9 +10,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.MinimalServiceSettingsTests; @@ -21,15 +26,18 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.InferenceIndex; +import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; +import org.hamcrest.Matchers; import org.junit.Before; import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -52,7 +60,7 @@ public class ModelRegistryTests extends ESSingleNodeTestCase { @Override protected Collection> getPlugins() { - return List.of(LocalStateInferencePlugin.class); + return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } @Before @@ -237,6 +245,96 @@ public void testDuplicateDefaultIds() { ); } + public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(false); + } + + public void testStoreModel_DeletesIndexDocs_WhenInferenceSecretsIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(true); + } + + public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() { + var model = new TestModel( + "model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + PlainActionFuture firstStoreListener = new PlainActionFuture<>(); + registry.storeModel(model, firstStoreListener, TimeValue.THIRTY_SECONDS); + firstStoreListener.actionGet(TimeValue.THIRTY_SECONDS); + + assertIndicesContainExpectedDocsCount(model, 2); + + PlainActionFuture secondStoreListener = new PlainActionFuture<>(); + registry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS); + + expectThrows(ResourceAlreadyExistsException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); + + assertIndicesContainExpectedDocsCount(model, 2); + } + + private void storeCorruptedModelThenStoreModel(boolean storeSecrets) { + var model = new TestModel( + "corrupted-model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + storeCorruptedModel(model, storeSecrets); + + assertIndicesContainExpectedDocsCount(model, 1); + + PlainActionFuture storeListener = new PlainActionFuture<>(); + registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); + + expectThrows(ResourceAlreadyExistsException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + + assertIndicesContainExpectedDocsCount(model, 0); + } + + private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) { + SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId())))) + .setSize(2) + .setTrackTotalHits(false) + .request(); + SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS); + try { + assertThat(searchResponse.getHits().getHits(), Matchers.arrayWithSize(numberOfDocs)); + } finally { + searchResponse.decRef(); + } + } + + private void storeCorruptedModel(Model model, boolean storeSecrets) { + var listener = new PlainActionFuture(); + + client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add( + ModelRegistry.createIndexRequestBuilder( + model.getInferenceEntityId(), + storeSecrets ? InferenceSecretsIndex.INDEX_NAME : InferenceIndex.INDEX_NAME, + storeSecrets ? model.getSecrets() : model.getConfigurations(), + false, + client() + ) + ) + .execute(listener); + + var bulkResponse = listener.actionGet(TIMEOUT); + if (bulkResponse.hasFailures()) { + fail("Failed to store model: " + bulkResponse.buildFailureMessage()); + } + } + public static void assertStoreModel(ModelRegistry registry, Model model) { PlainActionFuture storeListener = new PlainActionFuture<>(); registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS);