Merged
Changes from 2 commits
6 changes: 6 additions & 0 deletions docs/changelog/136577.yaml
@@ -0,0 +1,6 @@
pr: 136577
summary: Clean up inference indices on failed endpoint creation
area: Machine Learning
type: bug
issues:
- 123726
@@ -21,6 +21,7 @@
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.GroupedActionListener;
@@ -531,11 +532,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi

SubscribableListener.<BulkResponse>newForked((subListener) -> {
// in this block, we try to update the stored model configurations
IndexRequest configRequest = createIndexRequest(
Model.documentId(inferenceEntityId),
var configRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceIndex.INDEX_NAME,
newModel.getConfigurations(),
true
true,
client
);

ActionListener<BulkResponse> storeConfigListener = subListener.delegateResponse((l, e) -> {
@@ -544,7 +546,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
l.onFailure(e);
});

client.prepareBulk().add(configRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener);
client.prepareBulk()
.add(configRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(storeConfigListener);

}).<BulkResponse>andThen((subListener, configResponse) -> {
// in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets
@@ -569,11 +574,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
);
} else {
// Since the model configurations were successfully updated, we can now try to store the new secrets
IndexRequest secretsRequest = createIndexRequest(
Model.documentId(newModel.getConfigurations().getInferenceEntityId()),
var secretsRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceSecretsIndex.INDEX_NAME,
newModel.getSecrets(),
true
true,
client
);

ActionListener<BulkResponse> storeSecretsListener = subListener.delegateResponse((l, e) -> {
@@ -583,20 +589,22 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
});

client.prepareBulk()
.add(secretsRequest)
.add(secretsRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(storeSecretsListener);
}
}).<BulkResponse>andThen((subListener, secretsResponse) -> {
// in this block, we respond to the success or failure of updating the model secrets
if (secretsResponse.hasFailures()) {
// since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations
IndexRequest configRequest = createIndexRequest(
Model.documentId(inferenceEntityId),
var configRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceIndex.INDEX_NAME,
existingModel.getConfigurations(),
true
true,
client
);

logger.error(
"Failed to update inference endpoint secrets [{}], attempting rolling back to previous state",
inferenceEntityId
@@ -608,7 +616,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
l.onFailure(e);
});
client.prepareBulk()
.add(configRequest)
.add(configRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(rollbackConfigListener);
} else {
@@ -655,24 +663,25 @@ public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue

private void storeModel(Model model, boolean updateClusterState, ActionListener<Boolean> listener, TimeValue timeout) {
ActionListener<BulkResponse> bulkResponseActionListener = getStoreIndexListener(model, updateClusterState, listener, timeout);

IndexRequest configRequest = createIndexRequest(
Model.documentId(model.getConfigurations().getInferenceEntityId()),
String inferenceEntityId = model.getConfigurations().getInferenceEntityId();
var configRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceIndex.INDEX_NAME,
model.getConfigurations(),
false
false,
client
);

IndexRequest secretsRequest = createIndexRequest(
Model.documentId(model.getConfigurations().getInferenceEntityId()),
var secretsRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceSecretsIndex.INDEX_NAME,
model.getSecrets(),
false
false,
client
);

client.prepareBulk()
.add(configRequest)
.add(secretsRequest)
.add(configRequestBuilder)
.add(secretsRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(bulkResponseActionListener);
}
@@ -683,8 +692,11 @@ private ActionListener<BulkResponse> getStoreIndexListener(
ActionListener<Boolean> listener,
TimeValue timeout
) {
var cleanupListener = listener.delegateResponse(
(delegate, ex) -> deleteModel(model.getInferenceEntityId(), ActionListener.running(() -> delegate.onFailure(ex)))
);
return ActionListener.wrap(bulkItemResponses -> {
var inferenceEntityId = model.getConfigurations().getInferenceEntityId();
var inferenceEntityId = model.getInferenceEntityId();

if (bulkItemResponses.getItems().length == 0) {
logger.warn(
@@ -728,22 +740,33 @@ private ActionListener<BulkResponse> getStoreIndexListener(
return;
}

logBulkFailures(model.getConfigurations().getInferenceEntityId(), bulkItemResponses);
boolean anySuccess = false;
Contributor:
nit: It might be slightly cleaner to move this boolean above the cleanupListener. Then cleanupListener can check anySuccess and, if there was a success, do the delete.

Then we don't need to track listenerToInvoke.

Contributor Author:
Good idea, done.
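
A rough sketch of what the reviewer's suggestion could look like, for illustration only: it is a fragment of getStoreIndexListener, not the merged code, and the AtomicBoolean hoisting and the `bulkFailure` local are assumptions introduced here.

```java
// Sketch of the suggestion only -- not the merged change. Assumes anySuccess is hoisted
// into an AtomicBoolean above cleanupListener so the failure handler can read it later.
var anySuccess = new AtomicBoolean(false);

var cleanupListener = listener.delegateResponse((delegate, ex) -> {
    if (anySuccess.get()) {
        // At least one document was written, so roll back the partial state before failing.
        deleteModel(model.getInferenceEntityId(), ActionListener.running(() -> delegate.onFailure(ex)));
    } else {
        // Nothing was written; fail without issuing a delete.
        delegate.onFailure(ex);
    }
});

return ActionListener.wrap(bulkItemResponses -> {
    // ... empty-response and all-success handling unchanged ...
    for (BulkItemResponse aResponse : bulkItemResponses.getItems()) {
        logBulkFailure(inferenceEntityId, aResponse);
        if (aResponse.isFailed() == false) {
            anySuccess.set(true);
        }
    }
    // Every failure path now reports through cleanupListener, which decides whether a
    // rollback delete is needed, so there is no listenerToInvoke to track.
    cleanupListener.onFailure(bulkFailure); // bulkFailure: hypothetical stand-in for the exception built above
}, e -> { /* unchanged */ });
```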

ActionListener<Boolean> listenerToInvoke;
for (BulkItemResponse aResponse : bulkItemResponses.getItems()) {
logBulkFailure(inferenceEntityId, aResponse);
anySuccess |= aResponse.isFailed() == false;
}
// If there was a partial failure in writing to the indices, we need to clean up
if (anySuccess) {
listenerToInvoke = cleanupListener;
} else {
listenerToInvoke = listener;
}

if (ExceptionsHelper.unwrapCause(failure.getCause()) instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId));
listenerToInvoke.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId));
return;
}

listener.onFailure(
listenerToInvoke.onFailure(
new ElasticsearchStatusException(
format("Failed to store inference endpoint [%s]", inferenceEntityId),
RestStatus.INTERNAL_SERVER_ERROR,
failure.getCause()
)
);
}, e -> {
String errorMessage = format("Failed to store inference endpoint [%s]", model.getConfigurations().getInferenceEntityId());
String errorMessage = format("Failed to store inference endpoint [%s]", model.getInferenceEntityId());
logger.warn(errorMessage, e);
listener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e));
});
@@ -779,18 +802,12 @@ public void onFailure(Exception exc) {
};
}

private static void logBulkFailures(String inferenceEntityId, BulkResponse bulkResponse) {
for (BulkItemResponse item : bulkResponse.getItems()) {
if (item.isFailed()) {
logger.warn(
format(
"Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]",
inferenceEntityId,
item.getIndex(),
item.getFailureMessage()
)
);
}
private static void logBulkFailure(String inferenceEntityId, BulkItemResponse item) {
if (item.isFailed()) {
logger.warn(
format("Failed to store inference endpoint [%s] index: [%s]", inferenceEntityId, item.getIndex()),
item.getFailure().getCause()
);
}
}

@@ -937,6 +954,33 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T
}
}

static IndexRequestBuilder createIndexRequestBuilder(
String inferenceId,
String indexName,
ToXContentObject body,
boolean allowOverwriting,
Client client
) {
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) {
XContentBuilder source = body.toXContent(
xContentBuilder,
new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString()))
);

return new IndexRequestBuilder(client).setIndex(indexName)
.setCreate(allowOverwriting == false)
.setId(Model.documentId(inferenceId))
.setSource(source);
} catch (IOException ex) {
throw new ElasticsearchException(
"Unexpected serialization exception for index [{}] inference ID [{}]",
ex,
indexName,
inferenceId
);
}
}

private static UnparsedModel modelToUnparsedModel(Model model) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
model.getConfigurations()
@@ -10,9 +10,14 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.MinimalServiceSettingsTests;
@@ -21,15 +26,18 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.xpack.inference.InferenceIndex;
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.hamcrest.Matchers;
import org.junit.Before;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
@@ -52,7 +60,7 @@ public class ModelRegistryTests extends ESSingleNodeTestCase {

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(LocalStateInferencePlugin.class);
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
}

@Before
@@ -237,6 +245,96 @@ public void testDuplicateDefaultIds() {
);
}

public void testStoreModel_DeletesIndexDocs_WhenInferenceIndexDocumentAlreadyExists() {
storeCorruptedModelThenStoreModel(false);
}

public void testStoreModel_DeletesIndexDocs_WhenInferenceSecretsIndexDocumentAlreadyExists() {
storeCorruptedModelThenStoreModel(true);
}

public void testStoreModel_DoesNotDeleteIndexDocs_WhenModelAlreadyExists() {
var model = new TestModel(
"model-id",
TaskType.SPARSE_EMBEDDING,
"foo",
new TestModel.TestServiceSettings(null, null, null, null),
new TestModel.TestTaskSettings(randomInt(3)),
new TestModel.TestSecretSettings("secret")
);

PlainActionFuture<Boolean> firstStoreListener = new PlainActionFuture<>();
registry.storeModel(model, firstStoreListener, TimeValue.THIRTY_SECONDS);
firstStoreListener.actionGet(TimeValue.THIRTY_SECONDS);

assertIndicesContainExpectedDocsCount(model, 2);

PlainActionFuture<Boolean> secondStoreListener = new PlainActionFuture<>();
registry.storeModel(model, secondStoreListener, TimeValue.THIRTY_SECONDS);

expectThrows(ResourceAlreadyExistsException.class, () -> secondStoreListener.actionGet(TimeValue.THIRTY_SECONDS));

assertIndicesContainExpectedDocsCount(model, 2);
}

private void storeCorruptedModelThenStoreModel(boolean storeSecrets) {
var model = new TestModel(
"corrupted-model-id",
TaskType.SPARSE_EMBEDDING,
"foo",
new TestModel.TestServiceSettings(null, null, null, null),
new TestModel.TestTaskSettings(randomInt(3)),
new TestModel.TestSecretSettings("secret")
);

storeCorruptedModel(model, storeSecrets);

assertIndicesContainExpectedDocsCount(model, 1);

PlainActionFuture<Boolean> storeListener = new PlainActionFuture<>();
registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS);

expectThrows(ResourceAlreadyExistsException.class, () -> storeListener.actionGet(TimeValue.THIRTY_SECONDS));

assertIndicesContainExpectedDocsCount(model, 0);
}

private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) {
SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId()))))
.setSize(2)
.setTrackTotalHits(false)
.request();
SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS);
try {
assertThat(searchResponse.getHits().getHits(), Matchers.arrayWithSize(numberOfDocs));
} finally {
searchResponse.decRef();
}
}

private void storeCorruptedModel(Model model, boolean storeSecrets) {
var listener = new PlainActionFuture<BulkResponse>();

client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(
ModelRegistry.createIndexRequestBuilder(
model.getInferenceEntityId(),
storeSecrets ? InferenceSecretsIndex.INDEX_NAME : InferenceIndex.INDEX_NAME,
storeSecrets ? model.getSecrets() : model.getConfigurations(),
false,
client()
)
)
.execute(listener);

var bulkResponse = listener.actionGet(TIMEOUT);
if (bulkResponse.hasFailures()) {
fail("Failed to store model: " + bulkResponse.buildFailureMessage());
}
}

public static void assertStoreModel(ModelRegistry registry, Model model) {
PlainActionFuture<Boolean> storeListener = new PlainActionFuture<>();
registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS);