diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index a24cc9b015..df0d4beb79 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -10,6 +10,7 @@ import java.time.Instant; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -17,14 +18,13 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -56,8 +56,10 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.BulkDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -66,7 +68,6 @@ import org.opensearch.transport.TransportService; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; @@ -213,7 +214,13 @@ private void undeployModels( return modelCacheMissForModelIds; }); if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) { - bulkSetModelIndexToUndeploy(modelIds, listener, response); + log + .warn( + "Model undeployment fallback: No active nodes found for models {}." + + " Proceeding with manual index update to UNDEPLOY state.", + Arrays.toString(modelIds) + ); + bulkSetModelIndexToUndeploy(modelIds, tenantId, listener, response); return; } listener.onResponse(new MLUndeployModelsResponse(response)); @@ -222,34 +229,39 @@ private void undeployModels( private void bulkSetModelIndexToUndeploy( String[] modelIds, + String tenantId, ActionListener listener, - MLUndeployModelNodesResponse response + MLUndeployModelNodesResponse mlUndeployModelNodesResponse ) { - BulkRequest bulkUpdateRequest = new BulkRequest(); + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build(); + for (String modelId : modelIds) { - UpdateRequest updateRequest = new UpdateRequest(); - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name()); + Map updateDocument = new HashMap<>(); - builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of()); - builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0); + updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name()); + updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of()); + updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0); + updateDocument.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0); - builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0); - updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build()); - bulkUpdateRequest.add(updateRequest); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .id(modelId) + .tenantId(tenantId) + .dataObject(updateDocument) + .build(); + bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); } - bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); log.info("No nodes running these models: {}", Arrays.toString(modelIds)); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener listenerWithContextRestoration = ActionListener .runBefore(listener, () -> threadContext.restore()); + ActionListener bulkResponseListener = ActionListener.wrap(br -> { - log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds)); - listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response)); + listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(mlUndeployModelNodesResponse)); }, e -> { String modelsNotFoundMessage = String .format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds)); @@ -262,7 +274,40 @@ private void bulkSetModelIndexToUndeploy( listenerWithContextRestoration.onFailure(exception); }); - client.bulk(bulkUpdateRequest, bulkResponseListener); + sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((response, exception) -> { + if (exception != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(exception, OpenSearchStatusException.class); + bulkResponseListener.onFailure(cause); + return; + } + + try { + BulkResponse bulkResponse = BulkResponse.fromXContent(response.parser()); + log + .info( + "Executed {} bulk operations with {} failures, Took: {}", + bulkResponse.getItems().length, + bulkResponse.hasFailures() + ? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count() + : 0, + bulkResponse.getTook() + ); + List unemployedModelIds = Arrays + .stream(bulkResponse.getItems()) + .filter(bulkItemResponse -> !bulkItemResponse.isFailed()) + .map(BulkItemResponse::getId) + .collect(Collectors.toList()); + log + .debug( + "Successfully set the following modelId(s) to UNDEPLOY in index: {}", + Arrays.toString(unemployedModelIds.toArray()) + ); + + bulkResponseListener.onResponse(bulkResponse); + } catch (Exception e) { + bulkResponseListener.onFailure(e); + } + }); } catch (Exception e) { log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e); listener.onFailure(e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index ef088ab2b3..da7c791c52 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,12 +33,17 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.service.ClusterService; @@ -46,6 +52,7 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -62,6 +69,7 @@ import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -135,6 +143,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); + sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap())); + transportUndeployModelsAction = spy( new TransportUndeployModelsAction( transportService, @@ -217,11 +227,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() { ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + BulkResponse bulkResponse = getSuccessBulkResponse(); // mock the bulk response that can be captured for inspecting the contents of the write to index doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - BulkResponse bulkResponse = mock(BulkResponse.class); - when(bulkResponse.hasFailures()).thenReturn(false); listener.onResponse(bulkResponse); return null; }).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class)); @@ -333,11 +342,10 @@ public void testHiddenModelSuccess() { return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); + BulkResponse bulkResponse = getSuccessBulkResponse(); // Mock the client.bulk call doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - BulkResponse bulkResponse = mock(BulkResponse.class); - when(bulkResponse.hasFailures()).thenReturn(false); listener.onResponse(bulkResponse); return null; }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); @@ -392,9 +400,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() { return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); + BulkResponse bulkResponse = getSuccessBulkResponse(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onResponse(mock(BulkResponse.class)); + listener.onResponse(bulkResponse); return null; }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); @@ -458,17 +467,18 @@ public void testDoExecute() { listener.onResponse(response); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); - // Mock the client.bulk call + + BulkResponse bulkResponse = getSuccessBulkResponse(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - BulkResponse bulkResponse = mock(BulkResponse.class); - when(bulkResponse.hasFailures()).thenReturn(false); listener.onResponse(bulkResponse); return null; - }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + }).when(client).bulk(any(), any()); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); } @@ -534,4 +544,16 @@ public void testDoExecute_modelIds_moreThan1() { MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); } + + private BulkResponse getSuccessBulkResponse() { + return new BulkResponse( + new BulkItemResponse[] { + new BulkItemResponse( + 1, + DocWriteRequest.OpType.UPDATE, + new UpdateResponse(new ShardId(ML_MODEL_INDEX, "modelId123", 0), "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED) + ) }, + 100L + ); + } }