Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
Expand Down Expand Up @@ -627,8 +628,8 @@ public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue
storeModel(model, true, listener, timeout);
}

private void storeModel(Model model, boolean addToClusterState, ActionListener<Boolean> listener, TimeValue timeout) {
ActionListener<BulkResponse> bulkResponseActionListener = getStoreIndexListener(model, addToClusterState, listener, timeout);
private void storeModel(Model model, boolean updateClusterState, ActionListener<Boolean> listener, TimeValue timeout) {
ActionListener<BulkResponse> bulkResponseActionListener = getStoreIndexListener(model, updateClusterState, listener, timeout);

IndexRequest configRequest = createIndexRequest(
Model.documentId(model.getConfigurations().getInferenceEntityId()),
Expand All @@ -653,7 +654,7 @@ private void storeModel(Model model, boolean addToClusterState, ActionListener<B

private ActionListener<BulkResponse> getStoreIndexListener(
Model model,
boolean addToClusterState,
boolean updateClusterState,
ActionListener<Boolean> listener,
TimeValue timeout
) {
Expand All @@ -680,7 +681,7 @@ private ActionListener<BulkResponse> getStoreIndexListener(
BulkItemResponse.Failure failure = getFirstBulkFailure(bulkItemResponses);

if (failure == null) {
if (addToClusterState) {
if (updateClusterState) {
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
try {
var projectId = clusterService.state().projectState().projectId();
Expand Down Expand Up @@ -777,14 +778,19 @@ public synchronized void removeDefaultConfigs(Set<String> inferenceEntityIds, Ac
}

defaultConfigIds.keySet().removeAll(inferenceEntityIds);
deleteModels(inferenceEntityIds, listener);
// default models are not stored in the cluster state.
deleteModels(inferenceEntityIds, false, listener);
}

public void deleteModel(String inferenceEntityId, ActionListener<Boolean> listener) {
deleteModels(Set.of(inferenceEntityId), listener);
}

public void deleteModels(Set<String> inferenceEntityIds, ActionListener<Boolean> listener) {
deleteModels(inferenceEntityIds, true, listener);
}

private void deleteModels(Set<String> inferenceEntityIds, boolean updateClusterState, ActionListener<Boolean> listener) {
var lockedInferenceIds = new HashSet<>(inferenceEntityIds);
lockedInferenceIds.retainAll(preventDeletionLock);

Expand All @@ -803,16 +809,25 @@ public void deleteModels(Set<String> inferenceEntityIds, ActionListener<Boolean>
}

var request = createDeleteRequest(inferenceEntityIds);
client.execute(DeleteByQueryAction.INSTANCE, request, getDeleteModelClusterStateListener(inferenceEntityIds, listener));
client.execute(
DeleteByQueryAction.INSTANCE,
request,
getDeleteModelClusterStateListener(inferenceEntityIds, updateClusterState, listener)
);
}

private ActionListener<BulkByScrollResponse> getDeleteModelClusterStateListener(
Set<String> inferenceEntityIds,
boolean updateClusterState,
ActionListener<Boolean> listener
) {
return new ActionListener<>() {
@Override
public void onResponse(BulkByScrollResponse bulkByScrollResponse) {
if (updateClusterState == false) {
listener.onResponse(Boolean.TRUE);
return;
}
var clusterStateListener = new ActionListener<AcknowledgedResponse>() {
@Override
public void onResponse(AcknowledgedResponse acknowledgedResponse) {
Expand Down Expand Up @@ -920,6 +935,11 @@ public void clusterChanged(ClusterChangedEvent event) {
return;
}

// wait for the cluster state to be recovered
if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) {
return;
}

if (event.state().metadata().projects().size() > 1) {
// TODO: Add support to handle multi-projects
return;
Expand Down