diff --git a/docs/changelog/136577.yaml b/docs/changelog/136577.yaml new file mode 100644 index 0000000000000..e48d831a41ff5 --- /dev/null +++ b/docs/changelog/136577.yaml @@ -0,0 +1,6 @@ +pr: 136577 +summary: Clean up inference indices on failed endpoint creation +area: Machine Learning +type: bug +issues: + - 123726 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 7cd1cf5999d11..a042839b23d4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -21,6 +21,7 @@ import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.GroupedActionListener; @@ -531,11 +532,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi SubscribableListener.newForked((subListener) -> { // in this block, we try to update the stored model configurations - IndexRequest configRequest = createIndexRequest( - Model.documentId(inferenceEntityId), + var configRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, newModel.getConfigurations(), - true + true, + client ); ActionListener storeConfigListener = subListener.delegateResponse((l, e) -> { @@ -544,7 +546,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); - client.prepareBulk().add(configRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener); + client.prepareBulk() + .add(configRequestBuilder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .execute(storeConfigListener); }).andThen((subListener, configResponse) -> { // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets @@ -569,11 +574,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi ); } else { // Since the model configurations were successfully updated, we can now try to store the new secrets - IndexRequest secretsRequest = createIndexRequest( - Model.documentId(newModel.getConfigurations().getInferenceEntityId()), + var secretsRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceSecretsIndex.INDEX_NAME, newModel.getSecrets(), - true + true, + client ); ActionListener storeSecretsListener = subListener.delegateResponse((l, e) -> { @@ -583,7 +589,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi }); client.prepareBulk() - .add(secretsRequest) + .add(secretsRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(storeSecretsListener); } @@ -591,12 +597,14 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi // in this block, we respond to the success or failure of updating the model secrets if (secretsResponse.hasFailures()) { // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations - IndexRequest configRequest = createIndexRequest( - Model.documentId(inferenceEntityId), + var configRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, existingModel.getConfigurations(), - true + true, + client ); + logger.error( "Failed to update inference endpoint secrets [{}], attempting rolling back to previous state", inferenceEntityId @@ -608,7 +616,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi l.onFailure(e); }); client.prepareBulk() - .add(configRequest) + .add(configRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(rollbackConfigListener); } else { @@ -655,24 +663,25 @@ public void storeModel(Model model, ActionListener listener, TimeValue private void storeModel(Model model, boolean updateClusterState, ActionListener listener, TimeValue timeout) { ActionListener bulkResponseActionListener = getStoreIndexListener(model, updateClusterState, listener, timeout); - - IndexRequest configRequest = createIndexRequest( - Model.documentId(model.getConfigurations().getInferenceEntityId()), + String inferenceEntityId = model.getConfigurations().getInferenceEntityId(); + var configRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceIndex.INDEX_NAME, model.getConfigurations(), - false + false, + client ); - - IndexRequest secretsRequest = createIndexRequest( - Model.documentId(model.getConfigurations().getInferenceEntityId()), + var secretsRequestBuilder = createIndexRequestBuilder( + inferenceEntityId, InferenceSecretsIndex.INDEX_NAME, model.getSecrets(), - false + false, + client ); client.prepareBulk() - .add(configRequest) - .add(secretsRequest) + .add(configRequestBuilder) + .add(secretsRequestBuilder) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .execute(bulkResponseActionListener); } @@ -683,15 +692,24 @@ private ActionListener getStoreIndexListener( ActionListener listener, TimeValue timeout ) { + // If there was a partial failure in writing to the indices, we need to clean up + AtomicBoolean partialFailure = new AtomicBoolean(false); + var cleanupListener = listener.delegateResponse((delegate, ex) -> { + if (partialFailure.get()) { + deleteModel(model.getInferenceEntityId(), ActionListener.running(() -> delegate.onFailure(ex))); + } else { + delegate.onFailure(ex); + } + }); return ActionListener.wrap(bulkItemResponses -> { - var inferenceEntityId = model.getConfigurations().getInferenceEntityId(); + var inferenceEntityId = model.getInferenceEntityId(); if (bulkItemResponses.getItems().length == 0) { logger.warn( format("Storing inference endpoint [%s] failed, no items were received from the bulk response", inferenceEntityId) ); - listener.onFailure( + cleanupListener.onFailure( new ElasticsearchStatusException( format( "Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service", @@ -707,7 +725,7 @@ private ActionListener getStoreIndexListener( if (failure == null) { if (updateClusterState) { - var storeListener = getStoreMetadataListener(inferenceEntityId, listener); + var storeListener = getStoreMetadataListener(inferenceEntityId, cleanupListener); try { metadataTaskQueue.submitTask( "add model [" + inferenceEntityId + "]", @@ -723,19 +741,22 @@ private ActionListener getStoreIndexListener( storeListener.onFailure(exc); } } else { - listener.onResponse(Boolean.TRUE); + cleanupListener.onResponse(Boolean.TRUE); } return; } - logBulkFailures(model.getConfigurations().getInferenceEntityId(), bulkItemResponses); + for (BulkItemResponse aResponse : bulkItemResponses.getItems()) { + logBulkFailure(inferenceEntityId, aResponse); + partialFailure.compareAndSet(false, aResponse.isFailed() == false); + } if (ExceptionsHelper.unwrapCause(failure.getCause()) instanceof VersionConflictEngineException) { - listener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId)); + cleanupListener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId)); return; } - listener.onFailure( + cleanupListener.onFailure( new ElasticsearchStatusException( format("Failed to store inference endpoint [%s]", inferenceEntityId), RestStatus.INTERNAL_SERVER_ERROR, @@ -743,9 +764,9 @@ private ActionListener getStoreIndexListener( ) ); }, e -> { - String errorMessage = format("Failed to store inference endpoint [%s]", model.getConfigurations().getInferenceEntityId()); + String errorMessage = format("Failed to store inference endpoint [%s]", model.getInferenceEntityId()); logger.warn(errorMessage, e); - listener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e)); + cleanupListener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e)); }); } @@ -779,18 +800,12 @@ public void onFailure(Exception exc) { }; } - private static void logBulkFailures(String inferenceEntityId, BulkResponse bulkResponse) { - for (BulkItemResponse item : bulkResponse.getItems()) { - if (item.isFailed()) { - logger.warn( - format( - "Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]", - inferenceEntityId, - item.getIndex(), - item.getFailureMessage() - ) - ); - } + private static void logBulkFailure(String inferenceEntityId, BulkItemResponse item) { + if (item.isFailed()) { + logger.warn( + format("Failed to store inference endpoint [%s] index: [%s]", inferenceEntityId, item.getIndex()), + item.getFailure().getCause() + ); } } @@ -937,6 +952,33 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T } } + static IndexRequestBuilder createIndexRequestBuilder( + String inferenceId, + String indexName, + ToXContentObject body, + boolean allowOverwriting, + Client client + ) { + try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) { + XContentBuilder source = body.toXContent( + xContentBuilder, + new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString())) + ); + + return new IndexRequestBuilder(client).setIndex(indexName) + .setCreate(allowOverwriting == false) + .setId(Model.documentId(inferenceId)) + .setSource(source); + } catch (IOException ex) { + throw new ElasticsearchException( + "Unexpected serialization exception for index [{}] inference ID [{}]", + ex, + indexName, + inferenceId + ); + } + } + private static UnparsedModel modelToUnparsedModel(Model model) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { model.getConfigurations() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index eee8550ec6524..2980621e52c6f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -10,9 +10,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.MinimalServiceSettingsTests; @@ -21,15 +26,18 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.InferenceIndex; +import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; +import org.hamcrest.Matchers; import org.junit.Before; import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -52,7 +60,7 @@ public class ModelRegistryTests extends ESSingleNodeTestCase { @Override protected Collection> getPlugins() { - return List.of(LocalStateInferencePlugin.class); + return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); } @Before @@ -237,6 +245,96 @@ public void testDuplicateDefaultIds() { ); } + public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(false); + } + + public void testStoreModel_DeletesIndexDocs_WhenInferenceSecretsIndexDocumentAlreadyExists() { + storeCorruptedModelThenStoreModel(true); + } + + public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() { + var model = new TestModel( + "model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + PlainActionFuture firstStoreListener = new PlainActionFuture<>(); + registry.storeModel(model, firstStoreListener, TimeValue.THIRTY_SECONDS); + firstStoreListener.actionGet(TimeValue.THIRTY_SECONDS); + + assertIndicesContainExpectedDocsCount(model, 2); + + PlainActionFuture secondStoreListener = new PlainActionFuture<>(); + registry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS); + + expectThrows(ResourceAlreadyExistsException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS)); + + assertIndicesContainExpectedDocsCount(model, 2); + } + + private void storeCorruptedModelThenStoreModel(boolean storeSecrets) { + var model = new TestModel( + "corrupted-model-id", + TaskType.SPARSE_EMBEDDING, + "foo", + new TestModel.TestServiceSettings(null, null, null, null), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings("secret") + ); + + storeCorruptedModel(model, storeSecrets); + + assertIndicesContainExpectedDocsCount(model, 1); + + PlainActionFuture storeListener = new PlainActionFuture<>(); + registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS); + + expectThrows(ResourceAlreadyExistsException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS)); + + assertIndicesContainExpectedDocsCount(model, 0); + } + + private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) { + SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId())))) + .setSize(2) + .setTrackTotalHits(false) + .request(); + SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS); + try { + assertThat(searchResponse.getHits().getHits(), Matchers.arrayWithSize(numberOfDocs)); + } finally { + searchResponse.decRef(); + } + } + + private void storeCorruptedModel(Model model, boolean storeSecrets) { + var listener = new PlainActionFuture(); + + client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add( + ModelRegistry.createIndexRequestBuilder( + model.getInferenceEntityId(), + storeSecrets ? InferenceSecretsIndex.INDEX_NAME : InferenceIndex.INDEX_NAME, + storeSecrets ? model.getSecrets() : model.getConfigurations(), + false, + client() + ) + ) + .execute(listener); + + var bulkResponse = listener.actionGet(TIMEOUT); + if (bulkResponse.hasFailures()) { + fail("Failed to store model: " + bulkResponse.buildFailureMessage()); + } + } + public static void assertStoreModel(ModelRegistry registry, Model model) { PlainActionFuture storeListener = new PlainActionFuture<>(); registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS);