Skip to content
Closed

2.19 #4078

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@

import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
Expand Down Expand Up @@ -56,8 +56,10 @@
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand All @@ -66,7 +68,6 @@
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -213,7 +214,13 @@ private void undeployModels(
return modelCacheMissForModelIds;
});
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
bulkSetModelIndexToUndeploy(modelIds, listener, response);
log
.warn(
"Model undeployment fallback: No active nodes found for models {}."
+ " Proceeding with manual index update to UNDEPLOY state.",
Arrays.toString(modelIds)
);
bulkSetModelIndexToUndeploy(modelIds, tenantId, listener, response);
return;
}
listener.onResponse(new MLUndeployModelsResponse(response));
Expand All @@ -222,34 +229,39 @@ private void undeployModels(

private void bulkSetModelIndexToUndeploy(
String[] modelIds,
String tenantId,
ActionListener<MLUndeployModelsResponse> listener,
MLUndeployModelNodesResponse response
MLUndeployModelNodesResponse mlUndeployModelNodesResponse
) {
BulkRequest bulkUpdateRequest = new BulkRequest();
BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build();

for (String modelId : modelIds) {
UpdateRequest updateRequest = new UpdateRequest();

ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
Map<String, Object> updateDocument = new HashMap<>();

builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
updateDocument.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);

builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
bulkUpdateRequest.add(updateRequest);
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
.builder()
.id(modelId)
.tenantId(tenantId)
.dataObject(updateDocument)
.build();
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
}

bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
log.info("No nodes running these models: {}", Arrays.toString(modelIds));

try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
.runBefore(listener, () -> threadContext.restore());

ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds));
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(mlUndeployModelNodesResponse));
}, e -> {
String modelsNotFoundMessage = String
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
Expand All @@ -262,7 +274,40 @@ private void bulkSetModelIndexToUndeploy(
listenerWithContextRestoration.onFailure(exception);
});

client.bulk(bulkUpdateRequest, bulkResponseListener);
sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((response, exception) -> {
if (exception != null) {
Exception cause = SdkClientUtils.unwrapAndConvertToException(exception, OpenSearchStatusException.class);
bulkResponseListener.onFailure(cause);
return;
}

try {
BulkResponse bulkResponse = BulkResponse.fromXContent(response.parser());
log
.info(
"Executed {} bulk operations with {} failures, Took: {}",
bulkResponse.getItems().length,
bulkResponse.hasFailures()
? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count()
: 0,
bulkResponse.getTook()
);
List<String> unemployedModelIds = Arrays
.stream(bulkResponse.getItems())
.filter(bulkItemResponse -> !bulkItemResponse.isFailed())
.map(BulkItemResponse::getId)
.collect(Collectors.toList());
log
.debug(
"Successfully set the following modelId(s) to UNDEPLOY in index: {}",
Arrays.toString(unemployedModelIds.toArray())
);

bulkResponseListener.onResponse(bulkResponse);
} catch (Exception e) {
bulkResponseListener.onFailure(e);
}
});
} catch (Exception e) {
log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -32,12 +33,17 @@
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -46,6 +52,7 @@
import org.opensearch.commons.ConfigConstants;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
Expand All @@ -62,6 +69,7 @@
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -135,6 +143,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase {
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
Settings settings = Settings.builder().build();
sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()));

transportUndeployModelsAction = spy(
new TransportUndeployModelsAction(
transportService,
Expand Down Expand Up @@ -217,11 +227,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {

ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);

BulkResponse bulkResponse = getSuccessBulkResponse();
// mock the bulk response that can be captured for inspecting the contents of the write to index
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
Expand Down Expand Up @@ -333,11 +342,10 @@ public void testHiddenModelSuccess() {
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

BulkResponse bulkResponse = getSuccessBulkResponse();
// Mock the client.bulk call
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
Expand Down Expand Up @@ -392,9 +400,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

BulkResponse bulkResponse = getSuccessBulkResponse();
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
listener.onResponse(mock(BulkResponse.class));
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

Expand Down Expand Up @@ -458,17 +467,18 @@ public void testDoExecute() {
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));
// Mock the client.bulk call

BulkResponse bulkResponse = getSuccessBulkResponse();
doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
BulkResponse bulkResponse = mock(BulkResponse.class);
when(bulkResponse.hasFailures()).thenReturn(false);
listener.onResponse(bulkResponse);
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}).when(client).bulk(any(), any());

MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
transportUndeployModelsAction.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));

verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}
Expand Down Expand Up @@ -534,4 +544,16 @@ public void testDoExecute_modelIds_moreThan1() {
MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null);
transportUndeployModelsAction.doExecute(task, request, actionListener);
}

/**
 * Builds a real (non-mock) {@link BulkResponse} for tests that need a bulk result
 * containing exactly one item.
 *
 * <p>The single item is an {@code UPDATE} operation wrapping an {@link UpdateResponse}
 * with document id {@code "id1"} and result {@link DocWriteResponse.Result#UPDATED},
 * addressed to shard 0 of {@code ML_MODEL_INDEX}.
 *
 * @return a bulk response holding the one update item, constructed with {@code 100L}
 *         as its second (took-time) argument
 */
private BulkResponse getSuccessBulkResponse() {
    // Build from the innermost object outward so each constructor call stays readable.
    ShardId modelIndexShard = new ShardId(ML_MODEL_INDEX, "modelId123", 0);
    UpdateResponse updatedDoc = new UpdateResponse(modelIndexShard, "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED);
    BulkItemResponse updatedItem = new BulkItemResponse(1, DocWriteRequest.OpType.UPDATE, updatedDoc);

    return new BulkResponse(new BulkItemResponse[] { updatedItem }, 100L);
}
}
Loading