6 changes: 6 additions & 0 deletions docs/changelog/136577.yaml
@@ -0,0 +1,6 @@
pr: 136577
summary: Clean up inference indices on failed endpoint creation
area: Machine Learning
type: bug
issues:
- 123726
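
The Java changes that follow implement this cleanup: the bulk-store listener in the model registry is wrapped so that a partial failure (the configuration document was written to one inference index while the secrets document was rejected by the other, or vice versa) first deletes the partially stored endpoint and only then reports the failure. Below is a minimal, self-contained sketch of that wrap-the-callback pattern, using hypothetical names (StoreCallback, withCleanup, deleteEndpoint) rather than the actual Elasticsearch ActionListener plumbing:

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

final class CleanupOnPartialFailureSketch {

    // Stand-in for Elasticsearch's ActionListener in this sketch.
    interface StoreCallback {
        void onSuccess();
        void onFailure(Exception e);
    }

    // Wraps the caller's callback: if some bulk items succeeded before the overall store
    // failed (partialFailure was flipped to true), delete the half-written endpoint first,
    // then propagate the original failure.
    static StoreCallback withCleanup(
        String inferenceEntityId,
        StoreCallback delegate,
        AtomicBoolean partialFailure,
        Consumer<String> deleteEndpoint
    ) {
        return new StoreCallback() {
            @Override
            public void onSuccess() {
                delegate.onSuccess();
            }

            @Override
            public void onFailure(Exception e) {
                if (partialFailure.get()) {
                    deleteEndpoint.accept(inferenceEntityId); // roll back the partial write
                }
                delegate.onFailure(e);
            }
        };
    }
}

In the actual diff this role is played by cleanupListener (built with ActionListener.delegateResponse) and deleteModel, with partialFailure set while iterating the bulk item responses.
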
@@ -21,6 +21,7 @@
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.GroupedActionListener;
@@ -514,11 +515,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi

SubscribableListener.<BulkResponse>newForked((subListener) -> {
// in this block, we try to update the stored model configurations
IndexRequest configRequest = createIndexRequest(
Model.documentId(inferenceEntityId),
var configRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceIndex.INDEX_NAME,
newModel.getConfigurations(),
true
true,
client
);

ActionListener<BulkResponse> storeConfigListener = subListener.delegateResponse((l, e) -> {
@@ -527,7 +529,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
l.onFailure(e);
});

client.prepareBulk().add(configRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(storeConfigListener);
client.prepareBulk()
.add(configRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(storeConfigListener);

}).<BulkResponse>andThen((subListener, configResponse) -> {
// in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets
@@ -552,11 +557,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
);
} else {
// Since the model configurations were successfully updated, we can now try to store the new secrets
IndexRequest secretsRequest = createIndexRequest(
Model.documentId(newModel.getConfigurations().getInferenceEntityId()),
var secretsRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceSecretsIndex.INDEX_NAME,
newModel.getSecrets(),
true
true,
client
);

ActionListener<BulkResponse> storeSecretsListener = subListener.delegateResponse((l, e) -> {
@@ -566,20 +572,22 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
});

client.prepareBulk()
.add(secretsRequest)
.add(secretsRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(storeSecretsListener);
}
}).<BulkResponse>andThen((subListener, secretsResponse) -> {
// in this block, we respond to the success or failure of updating the model secrets
if (secretsResponse.hasFailures()) {
// since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations
IndexRequest configRequest = createIndexRequest(
Model.documentId(inferenceEntityId),
var configRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceIndex.INDEX_NAME,
existingModel.getConfigurations(),
true
true,
client
);

logger.error(
"Failed to update inference endpoint secrets [{}], attempting rolling back to previous state",
inferenceEntityId
@@ -591,7 +599,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
l.onFailure(e);
});
client.prepareBulk()
.add(configRequest)
.add(configRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(rollbackConfigListener);
} else {
@@ -637,24 +645,25 @@ public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue

private void storeModel(Model model, boolean updateClusterState, ActionListener<Boolean> listener, TimeValue timeout) {
ActionListener<BulkResponse> bulkResponseActionListener = getStoreIndexListener(model, updateClusterState, listener, timeout);

IndexRequest configRequest = createIndexRequest(
Model.documentId(model.getConfigurations().getInferenceEntityId()),
String inferenceEntityId = model.getConfigurations().getInferenceEntityId();
var configRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceIndex.INDEX_NAME,
model.getConfigurations(),
false
false,
client
);

IndexRequest secretsRequest = createIndexRequest(
Model.documentId(model.getConfigurations().getInferenceEntityId()),
var secretsRequestBuilder = createIndexRequestBuilder(
inferenceEntityId,
InferenceSecretsIndex.INDEX_NAME,
model.getSecrets(),
false
false,
client
);

client.prepareBulk()
.add(configRequest)
.add(secretsRequest)
.add(configRequestBuilder)
.add(secretsRequestBuilder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.execute(bulkResponseActionListener);
}
@@ -665,15 +674,24 @@ private ActionListener<BulkResponse> getStoreIndexListener(
ActionListener<Boolean> listener,
TimeValue timeout
) {
// If there was a partial failure in writing to the indices, we need to clean up
AtomicBoolean partialFailure = new AtomicBoolean(false);
var cleanupListener = listener.delegateResponse((delegate, ex) -> {
if (partialFailure.get()) {
deleteModel(model.getInferenceEntityId(), ActionListener.running(() -> delegate.onFailure(ex)));
} else {
delegate.onFailure(ex);
}
});
return ActionListener.wrap(bulkItemResponses -> {
var inferenceEntityId = model.getConfigurations().getInferenceEntityId();
var inferenceEntityId = model.getInferenceEntityId();

if (bulkItemResponses.getItems().length == 0) {
logger.warn(
format("Storing inference endpoint [%s] failed, no items were received from the bulk response", inferenceEntityId)
);

listener.onFailure(
cleanupListener.onFailure(
new ElasticsearchStatusException(
format(
"Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service",
@@ -689,7 +707,7 @@ private ActionListener<BulkResponse> getStoreIndexListener(

if (failure == null) {
if (updateClusterState) {
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
var storeListener = getStoreMetadataListener(inferenceEntityId, cleanupListener);
try {
metadataTaskQueue.submitTask(
"add model [" + inferenceEntityId + "]",
@@ -705,29 +723,32 @@ private ActionListener<BulkResponse> getStoreIndexListener(
storeListener.onFailure(exc);
}
} else {
listener.onResponse(Boolean.TRUE);
cleanupListener.onResponse(Boolean.TRUE);
}
return;
}

logBulkFailures(model.getConfigurations().getInferenceEntityId(), bulkItemResponses);
for (BulkItemResponse aResponse : bulkItemResponses.getItems()) {
logBulkFailure(inferenceEntityId, aResponse);
partialFailure.compareAndSet(false, aResponse.isFailed() == false);
}

if (ExceptionsHelper.unwrapCause(failure.getCause()) instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId));
cleanupListener.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", inferenceEntityId));
return;
}

listener.onFailure(
cleanupListener.onFailure(
new ElasticsearchStatusException(
format("Failed to store inference endpoint [%s]", inferenceEntityId),
RestStatus.INTERNAL_SERVER_ERROR,
failure.getCause()
)
);
}, e -> {
String errorMessage = format("Failed to store inference endpoint [%s]", model.getConfigurations().getInferenceEntityId());
String errorMessage = format("Failed to store inference endpoint [%s]", model.getInferenceEntityId());
logger.warn(errorMessage, e);
listener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e));
cleanupListener.onFailure(new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, e));
});
}

@@ -761,18 +782,12 @@ public void onFailure(Exception exc) {
};
}

private static void logBulkFailures(String inferenceEntityId, BulkResponse bulkResponse) {
for (BulkItemResponse item : bulkResponse.getItems()) {
if (item.isFailed()) {
logger.warn(
format(
"Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]",
inferenceEntityId,
item.getIndex(),
item.getFailureMessage()
)
);
}
private static void logBulkFailure(String inferenceEntityId, BulkItemResponse item) {
if (item.isFailed()) {
logger.warn(
format("Failed to store inference endpoint [%s] index: [%s]", inferenceEntityId, item.getIndex()),
item.getFailure().getCause()
);
}
}

@@ -905,6 +920,33 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T
}
}

static IndexRequestBuilder createIndexRequestBuilder(
String inferenceId,
String indexName,
ToXContentObject body,
boolean allowOverwriting,
Client client
) {
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) {
XContentBuilder source = body.toXContent(
xContentBuilder,
new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString()))
);

return new IndexRequestBuilder(client).setIndex(indexName)
.setCreate(allowOverwriting == false)
.setId(Model.documentId(inferenceId))
.setSource(source);
} catch (IOException ex) {
throw new ElasticsearchException(
"Unexpected serialization exception for index [{}] inference ID [{}]",
ex,
indexName,
inferenceId
);
}
}

private static UnparsedModel modelToUnparsedModel(Model model) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
model.getConfigurations()