Skip to content

Commit 82e0805

Browse files
committed
refactors undeploy models client with sdkClient bulk op
Signed-off-by: Brian Flores <[email protected]>
1 parent d6c7983 commit 82e0805

File tree

2 files changed

+114
-58
lines changed

2 files changed

+114
-58
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@
1010

1111
import java.time.Instant;
1212
import java.util.Arrays;
13+
import java.util.HashMap;
1314
import java.util.List;
1415
import java.util.Map;
1516
import java.util.stream.Collectors;
1617

1718
import org.opensearch.ExceptionsHelper;
1819
import org.opensearch.OpenSearchStatusException;
1920
import org.opensearch.action.ActionRequest;
20-
import org.opensearch.action.bulk.BulkRequest;
21+
import org.opensearch.action.bulk.BulkItemResponse;
2122
import org.opensearch.action.bulk.BulkResponse;
2223
import org.opensearch.action.search.SearchRequest;
2324
import org.opensearch.action.search.SearchResponse;
2425
import org.opensearch.action.support.ActionFilters;
2526
import org.opensearch.action.support.HandledTransportAction;
2627
import org.opensearch.action.support.WriteRequest;
27-
import org.opensearch.action.update.UpdateRequest;
28+
import org.opensearch.client.Client;
2829
import org.opensearch.cluster.service.ClusterService;
2930
import org.opensearch.common.inject.Inject;
3031
import org.opensearch.common.settings.Settings;
@@ -55,18 +56,18 @@
5556
import org.opensearch.ml.task.MLTaskManager;
5657
import org.opensearch.ml.utils.RestActionUtils;
5758
import org.opensearch.ml.utils.TenantAwareHelper;
59+
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
5860
import org.opensearch.remote.metadata.client.SdkClient;
5961
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
62+
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
6063
import org.opensearch.remote.metadata.common.SdkClientUtils;
6164
import org.opensearch.search.SearchHit;
6265
import org.opensearch.search.builder.SearchSourceBuilder;
6366
import org.opensearch.tasks.Task;
6467
import org.opensearch.threadpool.ThreadPool;
6568
import org.opensearch.transport.TransportService;
66-
import org.opensearch.transport.client.Client;
6769

6870
import com.google.common.annotations.VisibleForTesting;
69-
import com.google.common.collect.ImmutableMap;
7071

7172
import lombok.extern.log4j.Log4j2;
7273

@@ -217,8 +218,7 @@ private void undeployModels(
217218
return modelCacheMissForModelIds;
218219
});
219220
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
220-
log.warn("No node found running the model(s): {}", Arrays.toString(modelIds));
221-
bulkSetModelIndexToUndeploy(modelIds, listener, response);
221+
bulkSetModelIndexToUndeploy(modelIds, tenantId, listener, response);
222222
return;
223223
}
224224
log.info("Successfully undeployed model(s) from nodes: {}", Arrays.toString(modelIds));
@@ -227,48 +227,81 @@ private void undeployModels(
227227
}
228228

229229
private void bulkSetModelIndexToUndeploy(
230-
String[] modelIds,
231-
ActionListener<MLUndeployModelsResponse> listener,
232-
MLUndeployModelNodesResponse response
230+
String[] modelIds,
231+
String tenantId,
232+
ActionListener<MLUndeployModelsResponse> listener,
233+
MLUndeployModelNodesResponse mlUndeployModelNodesResponse
233234
) {
234-
BulkRequest bulkUpdateRequest = new BulkRequest();
235-
for (String modelId : modelIds) {
236-
UpdateRequest updateRequest = new UpdateRequest();
235+
BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build();
237236

238-
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
239-
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
240-
241-
builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
242-
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
237+
for (String modelId : modelIds) {
243238

244-
builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
245-
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
246-
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
247-
bulkUpdateRequest.add(updateRequest);
239+
Map<String, Object> updateDocument = new HashMap<>();
240+
241+
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
242+
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
243+
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
244+
updateDocument.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
245+
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
246+
247+
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
248+
.builder()
249+
.id(modelId)
250+
.tenantId(tenantId)
251+
.dataObject(updateDocument)
252+
.build();
253+
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
248254
}
249255

250-
bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
251256
log.info("No nodes running these models: {}", Arrays.toString(modelIds));
252257

253258
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
254259
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
255260
.runBefore(listener, () -> threadContext.restore());
261+
256262
ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
257-
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds));
258-
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
263+
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(mlUndeployModelNodesResponse));
259264
}, e -> {
260265
String modelsNotFoundMessage = String
261-
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
266+
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
262267
log.error(modelsNotFoundMessage, e);
263268

264269
OpenSearchStatusException exception = new OpenSearchStatusException(
265-
modelsNotFoundMessage + e.getMessage(),
266-
RestStatus.INTERNAL_SERVER_ERROR
270+
modelsNotFoundMessage + e.getMessage(),
271+
RestStatus.INTERNAL_SERVER_ERROR
267272
);
268273
listenerWithContextRestoration.onFailure(exception);
269274
});
270275

271-
client.bulk(bulkUpdateRequest, bulkResponseListener);
276+
sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((response, exception) -> {
277+
if (exception != null) {
278+
Exception cause = SdkClientUtils.unwrapAndConvertToException(exception, OpenSearchStatusException.class);
279+
bulkResponseListener.onFailure(cause);
280+
return;
281+
}
282+
283+
try {
284+
BulkResponse bulkResponse = BulkResponse.fromXContent(response.parser());
285+
log
286+
.info(
287+
"Executed {} bulk operations with {} failures, Took: {}",
288+
bulkResponse.getItems().length,
289+
bulkResponse.hasFailures()
290+
? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count()
291+
: 0,
292+
bulkResponse.getTook()
293+
);
294+
List<String> unemployedModelIds = Arrays.stream(bulkResponse.getItems())
295+
.filter(bulkItemResponse -> !bulkItemResponse.isFailed())
296+
.map(BulkItemResponse::getId)
297+
.collect(Collectors.toList());
298+
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(unemployedModelIds.toArray()));
299+
300+
bulkResponseListener.onResponse(bulkResponse);
301+
} catch (Exception e) {
302+
bulkResponseListener.onFailure(e);
303+
}
304+
});
272305
} catch (Exception e) {
273306
log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
274307
listener.onFailure(e);

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,31 @@
55

66
package org.opensearch.ml.action.undeploy;
77

8-
import static org.mockito.ArgumentMatchers.any;
9-
import static org.mockito.ArgumentMatchers.isA;
10-
import static org.mockito.Mockito.doAnswer;
11-
import static org.mockito.Mockito.doReturn;
12-
import static org.mockito.Mockito.doThrow;
13-
import static org.mockito.Mockito.mock;
14-
import static org.mockito.Mockito.never;
15-
import static org.mockito.Mockito.spy;
16-
import static org.mockito.Mockito.verify;
17-
import static org.mockito.Mockito.when;
18-
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
19-
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
20-
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
21-
22-
import java.io.IOException;
23-
import java.util.ArrayList;
24-
import java.util.HashMap;
25-
import java.util.List;
26-
import java.util.Map;
27-
288
import org.junit.Before;
299
import org.junit.Rule;
3010
import org.junit.rules.ExpectedException;
3111
import org.mockito.ArgumentCaptor;
3212
import org.mockito.Mock;
13+
import org.mockito.Mockito;
3314
import org.mockito.MockitoAnnotations;
15+
import org.opensearch.action.DocWriteRequest;
16+
import org.opensearch.action.DocWriteResponse;
3417
import org.opensearch.action.FailedNodeException;
18+
import org.opensearch.action.bulk.BulkItemResponse;
3519
import org.opensearch.action.bulk.BulkRequest;
3620
import org.opensearch.action.bulk.BulkResponse;
3721
import org.opensearch.action.support.ActionFilters;
3822
import org.opensearch.action.update.UpdateRequest;
23+
import org.opensearch.action.update.UpdateResponse;
24+
import org.opensearch.client.Client;
3925
import org.opensearch.cluster.ClusterName;
4026
import org.opensearch.cluster.service.ClusterService;
4127
import org.opensearch.common.settings.Settings;
4228
import org.opensearch.common.util.concurrent.ThreadContext;
4329
import org.opensearch.commons.ConfigConstants;
4430
import org.opensearch.commons.authuser.User;
4531
import org.opensearch.core.action.ActionListener;
32+
import org.opensearch.core.index.shard.ShardId;
4633
import org.opensearch.core.xcontent.NamedXContentRegistry;
4734
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
4835
import org.opensearch.ml.common.FunctionName;
@@ -59,11 +46,32 @@
5946
import org.opensearch.ml.task.MLTaskDispatcher;
6047
import org.opensearch.ml.task.MLTaskManager;
6148
import org.opensearch.remote.metadata.client.SdkClient;
49+
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
6250
import org.opensearch.tasks.Task;
6351
import org.opensearch.test.OpenSearchTestCase;
6452
import org.opensearch.threadpool.ThreadPool;
6553
import org.opensearch.transport.TransportService;
66-
import org.opensearch.transport.client.Client;
54+
55+
import java.io.IOException;
56+
import java.util.ArrayList;
57+
import java.util.Collections;
58+
import java.util.HashMap;
59+
import java.util.List;
60+
import java.util.Map;
61+
62+
import static org.mockito.ArgumentMatchers.any;
63+
import static org.mockito.ArgumentMatchers.isA;
64+
import static org.mockito.Mockito.doAnswer;
65+
import static org.mockito.Mockito.doReturn;
66+
import static org.mockito.Mockito.doThrow;
67+
import static org.mockito.Mockito.mock;
68+
import static org.mockito.Mockito.never;
69+
import static org.mockito.Mockito.spy;
70+
import static org.mockito.Mockito.verify;
71+
import static org.mockito.Mockito.when;
72+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
73+
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
74+
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
6775

6876
public class TransportUndeployModelsActionTests extends OpenSearchTestCase {
6977

@@ -133,6 +141,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase {
133141
public void setup() throws IOException {
134142
MockitoAnnotations.openMocks(this);
135143
Settings settings = Settings.builder().build();
144+
sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()));
145+
136146
transportUndeployModelsAction = spy(
137147
new TransportUndeployModelsAction(
138148
transportService,
@@ -215,11 +225,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
215225

216226
ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
217227

228+
BulkResponse bulkResponse = getSuccessBulkResponse();
218229
// mock the bulk response that can be captured for inspecting the contents of the write to index
219230
doAnswer(invocation -> {
220231
ActionListener<BulkResponse> listener = invocation.getArgument(1);
221-
BulkResponse bulkResponse = mock(BulkResponse.class);
222-
when(bulkResponse.hasFailures()).thenReturn(false);
223232
listener.onResponse(bulkResponse);
224233
return null;
225234
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
@@ -331,11 +340,10 @@ public void testHiddenModelSuccess() {
331340
return null;
332341
}).when(client).execute(any(), any(), isA(ActionListener.class));
333342

343+
BulkResponse bulkResponse = getSuccessBulkResponse();
334344
// Mock the client.bulk call
335345
doAnswer(invocation -> {
336346
ActionListener<BulkResponse> listener = invocation.getArgument(1);
337-
BulkResponse bulkResponse = mock(BulkResponse.class);
338-
when(bulkResponse.hasFailures()).thenReturn(false);
339347
listener.onResponse(bulkResponse);
340348
return null;
341349
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
@@ -390,9 +398,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
390398
return null;
391399
}).when(client).execute(any(), any(), isA(ActionListener.class));
392400

401+
BulkResponse bulkResponse = getSuccessBulkResponse();
393402
doAnswer(invocation -> {
394403
ActionListener<BulkResponse> listener = invocation.getArgument(1);
395-
listener.onResponse(mock(BulkResponse.class));
404+
listener.onResponse(bulkResponse);
396405
return null;
397406
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
398407

@@ -456,17 +465,18 @@ public void testDoExecute() {
456465
listener.onResponse(response);
457466
return null;
458467
}).when(client).execute(any(), any(), isA(ActionListener.class));
459-
// Mock the client.bulk call
468+
469+
BulkResponse bulkResponse = getSuccessBulkResponse();
460470
doAnswer(invocation -> {
461471
ActionListener<BulkResponse> listener = invocation.getArgument(1);
462-
BulkResponse bulkResponse = mock(BulkResponse.class);
463-
when(bulkResponse.hasFailures()).thenReturn(false);
464472
listener.onResponse(bulkResponse);
465473
return null;
466-
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
474+
}).when(client).bulk(any(), any());
467475

468476
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
469477
transportUndeployModelsAction.doExecute(task, request, actionListener);
478+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
479+
470480
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
471481
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
472482
}
@@ -532,4 +542,17 @@ public void testDoExecute_modelIds_moreThan1() {
532542
MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null);
533543
transportUndeployModelsAction.doExecute(task, request, actionListener);
534544
}
545+
546+
private BulkResponse getSuccessBulkResponse() {
547+
return new BulkResponse(
548+
new BulkItemResponse[]{
549+
new BulkItemResponse(
550+
1,
551+
DocWriteRequest.OpType.UPDATE,
552+
new UpdateResponse(new ShardId(ML_MODEL_INDEX, "modelId123", 0), "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED)
553+
)
554+
},
555+
100L
556+
);
557+
}
535558
}

0 commit comments

Comments
 (0)