Skip to content

Commit 8a69617

Browse files
committed
refactors undeploy models client with sdkClient bulk op
1 parent 8fca9cc commit 8a69617

File tree

2 files changed

+112
-55
lines changed

2 files changed

+112
-55
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@
1010

1111
import java.time.Instant;
1212
import java.util.Arrays;
13+
import java.util.HashMap;
1314
import java.util.List;
1415
import java.util.Map;
1516
import java.util.stream.Collectors;
1617

1718
import org.opensearch.ExceptionsHelper;
1819
import org.opensearch.OpenSearchStatusException;
1920
import org.opensearch.action.ActionRequest;
20-
import org.opensearch.action.bulk.BulkRequest;
21+
import org.opensearch.action.bulk.BulkItemResponse;
2122
import org.opensearch.action.bulk.BulkResponse;
2223
import org.opensearch.action.search.SearchRequest;
2324
import org.opensearch.action.search.SearchResponse;
2425
import org.opensearch.action.support.ActionFilters;
2526
import org.opensearch.action.support.HandledTransportAction;
2627
import org.opensearch.action.support.WriteRequest;
27-
import org.opensearch.action.update.UpdateRequest;
2828
import org.opensearch.client.Client;
2929
import org.opensearch.cluster.service.ClusterService;
3030
import org.opensearch.common.inject.Inject;
@@ -56,8 +56,10 @@
5656
import org.opensearch.ml.task.MLTaskManager;
5757
import org.opensearch.ml.utils.RestActionUtils;
5858
import org.opensearch.ml.utils.TenantAwareHelper;
59+
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
5960
import org.opensearch.remote.metadata.client.SdkClient;
6061
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
62+
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
6163
import org.opensearch.remote.metadata.common.SdkClientUtils;
6264
import org.opensearch.search.SearchHit;
6365
import org.opensearch.search.builder.SearchSourceBuilder;
@@ -66,7 +68,6 @@
6668
import org.opensearch.transport.TransportService;
6769

6870
import com.google.common.annotations.VisibleForTesting;
69-
import com.google.common.collect.ImmutableMap;
7071

7172
import lombok.extern.log4j.Log4j2;
7273

@@ -213,56 +214,89 @@ private void undeployModels(
213214
return modelCacheMissForModelIds;
214215
});
215216
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
216-
bulkSetModelIndexToUndeploy(modelIds, listener, response);
217+
bulkSetModelIndexToUndeploy(modelIds, tenantId, listener, response);
217218
return;
218219
}
219220
listener.onResponse(new MLUndeployModelsResponse(response));
220221
}, listener::onFailure));
221222
}
222223

223224
private void bulkSetModelIndexToUndeploy(
224-
String[] modelIds,
225-
ActionListener<MLUndeployModelsResponse> listener,
226-
MLUndeployModelNodesResponse response
225+
String[] modelIds,
226+
String tenantId,
227+
ActionListener<MLUndeployModelsResponse> listener,
228+
MLUndeployModelNodesResponse mlUndeployModelNodesResponse
227229
) {
228-
BulkRequest bulkUpdateRequest = new BulkRequest();
229-
for (String modelId : modelIds) {
230-
UpdateRequest updateRequest = new UpdateRequest();
230+
BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build();
231231

232-
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
233-
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
234-
235-
builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
236-
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
232+
for (String modelId : modelIds) {
237233

238-
builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
239-
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
240-
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
241-
bulkUpdateRequest.add(updateRequest);
234+
Map<String, Object> updateDocument = new HashMap<>();
235+
236+
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
237+
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
238+
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
239+
updateDocument.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
240+
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
241+
242+
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
243+
.builder()
244+
.id(modelId)
245+
.tenantId(tenantId)
246+
.dataObject(updateDocument)
247+
.build();
248+
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
242249
}
243250

244-
bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
245251
log.info("No nodes running these models: {}", Arrays.toString(modelIds));
246252

247253
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
248254
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
249255
.runBefore(listener, () -> threadContext.restore());
256+
250257
ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
251-
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds));
252-
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
258+
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(mlUndeployModelNodesResponse));
253259
}, e -> {
254260
String modelsNotFoundMessage = String
255-
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
261+
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
256262
log.error(modelsNotFoundMessage, e);
257263

258264
OpenSearchStatusException exception = new OpenSearchStatusException(
259-
modelsNotFoundMessage + e.getMessage(),
260-
RestStatus.INTERNAL_SERVER_ERROR
265+
modelsNotFoundMessage + e.getMessage(),
266+
RestStatus.INTERNAL_SERVER_ERROR
261267
);
262268
listenerWithContextRestoration.onFailure(exception);
263269
});
264270

265-
client.bulk(bulkUpdateRequest, bulkResponseListener);
271+
sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((response, exception) -> {
272+
if (exception != null) {
273+
Exception cause = SdkClientUtils.unwrapAndConvertToException(exception, OpenSearchStatusException.class);
274+
bulkResponseListener.onFailure(cause);
275+
return;
276+
}
277+
278+
try {
279+
BulkResponse bulkResponse = BulkResponse.fromXContent(response.parser());
280+
log
281+
.info(
282+
"Executed {} bulk operations with {} failures, Took: {}",
283+
bulkResponse.getItems().length,
284+
bulkResponse.hasFailures()
285+
? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count()
286+
: 0,
287+
bulkResponse.getTook()
288+
);
289+
List<String> unemployedModelIds = Arrays.stream(bulkResponse.getItems())
290+
.filter(bulkItemResponse -> !bulkItemResponse.isFailed())
291+
.map(BulkItemResponse::getId)
292+
.collect(Collectors.toList());
293+
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(unemployedModelIds.toArray()));
294+
295+
bulkResponseListener.onResponse(bulkResponse);
296+
} catch (Exception e) {
297+
bulkResponseListener.onFailure(e);
298+
}
299+
});
266300
} catch (Exception e) {
267301
log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
268302
listener.onFailure(e);

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,22 @@
77

88
package org.opensearch.ml.action.undeploy;
99

10-
import static org.mockito.ArgumentMatchers.any;
11-
import static org.mockito.ArgumentMatchers.isA;
12-
import static org.mockito.Mockito.doAnswer;
13-
import static org.mockito.Mockito.doReturn;
14-
import static org.mockito.Mockito.doThrow;
15-
import static org.mockito.Mockito.mock;
16-
import static org.mockito.Mockito.never;
17-
import static org.mockito.Mockito.spy;
18-
import static org.mockito.Mockito.verify;
19-
import static org.mockito.Mockito.when;
20-
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
21-
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
22-
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
23-
24-
import java.io.IOException;
25-
import java.util.ArrayList;
26-
import java.util.HashMap;
27-
import java.util.List;
28-
import java.util.Map;
29-
3010
import org.junit.Before;
3111
import org.junit.Rule;
3212
import org.junit.rules.ExpectedException;
3313
import org.mockito.ArgumentCaptor;
3414
import org.mockito.Mock;
15+
import org.mockito.Mockito;
3516
import org.mockito.MockitoAnnotations;
17+
import org.opensearch.action.DocWriteRequest;
18+
import org.opensearch.action.DocWriteResponse;
3619
import org.opensearch.action.FailedNodeException;
20+
import org.opensearch.action.bulk.BulkItemResponse;
3721
import org.opensearch.action.bulk.BulkRequest;
3822
import org.opensearch.action.bulk.BulkResponse;
3923
import org.opensearch.action.support.ActionFilters;
4024
import org.opensearch.action.update.UpdateRequest;
25+
import org.opensearch.action.update.UpdateResponse;
4126
import org.opensearch.client.Client;
4227
import org.opensearch.cluster.ClusterName;
4328
import org.opensearch.cluster.service.ClusterService;
@@ -46,6 +31,7 @@
4631
import org.opensearch.commons.ConfigConstants;
4732
import org.opensearch.commons.authuser.User;
4833
import org.opensearch.core.action.ActionListener;
34+
import org.opensearch.core.index.shard.ShardId;
4935
import org.opensearch.core.xcontent.NamedXContentRegistry;
5036
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
5137
import org.opensearch.ml.common.FunctionName;
@@ -62,11 +48,33 @@
6248
import org.opensearch.ml.task.MLTaskDispatcher;
6349
import org.opensearch.ml.task.MLTaskManager;
6450
import org.opensearch.remote.metadata.client.SdkClient;
51+
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
6552
import org.opensearch.tasks.Task;
6653
import org.opensearch.test.OpenSearchTestCase;
6754
import org.opensearch.threadpool.ThreadPool;
6855
import org.opensearch.transport.TransportService;
6956

57+
import java.io.IOException;
58+
import java.util.ArrayList;
59+
import java.util.Collections;
60+
import java.util.HashMap;
61+
import java.util.List;
62+
import java.util.Map;
63+
64+
import static org.mockito.ArgumentMatchers.any;
65+
import static org.mockito.ArgumentMatchers.isA;
66+
import static org.mockito.Mockito.doAnswer;
67+
import static org.mockito.Mockito.doReturn;
68+
import static org.mockito.Mockito.doThrow;
69+
import static org.mockito.Mockito.mock;
70+
import static org.mockito.Mockito.never;
71+
import static org.mockito.Mockito.spy;
72+
import static org.mockito.Mockito.verify;
73+
import static org.mockito.Mockito.when;
74+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
75+
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
76+
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
77+
7078
public class TransportUndeployModelsActionTests extends OpenSearchTestCase {
7179

7280
@Mock
@@ -135,6 +143,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase {
135143
public void setup() throws IOException {
136144
MockitoAnnotations.openMocks(this);
137145
Settings settings = Settings.builder().build();
146+
sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()));
147+
138148
transportUndeployModelsAction = spy(
139149
new TransportUndeployModelsAction(
140150
transportService,
@@ -217,11 +227,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
217227

218228
ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
219229

230+
BulkResponse bulkResponse = getSuccessBulkResponse();
220231
// mock the bulk response that can be captured for inspecting the contents of the write to index
221232
doAnswer(invocation -> {
222233
ActionListener<BulkResponse> listener = invocation.getArgument(1);
223-
BulkResponse bulkResponse = mock(BulkResponse.class);
224-
when(bulkResponse.hasFailures()).thenReturn(false);
225234
listener.onResponse(bulkResponse);
226235
return null;
227236
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
@@ -333,11 +342,10 @@ public void testHiddenModelSuccess() {
333342
return null;
334343
}).when(client).execute(any(), any(), isA(ActionListener.class));
335344

345+
BulkResponse bulkResponse = getSuccessBulkResponse();
336346
// Mock the client.bulk call
337347
doAnswer(invocation -> {
338348
ActionListener<BulkResponse> listener = invocation.getArgument(1);
339-
BulkResponse bulkResponse = mock(BulkResponse.class);
340-
when(bulkResponse.hasFailures()).thenReturn(false);
341349
listener.onResponse(bulkResponse);
342350
return null;
343351
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
@@ -392,9 +400,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
392400
return null;
393401
}).when(client).execute(any(), any(), isA(ActionListener.class));
394402

403+
BulkResponse bulkResponse = getSuccessBulkResponse();
395404
doAnswer(invocation -> {
396405
ActionListener<BulkResponse> listener = invocation.getArgument(1);
397-
listener.onResponse(mock(BulkResponse.class));
406+
listener.onResponse(bulkResponse);
398407
return null;
399408
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
400409

@@ -458,17 +467,18 @@ public void testDoExecute() {
458467
listener.onResponse(response);
459468
return null;
460469
}).when(client).execute(any(), any(), isA(ActionListener.class));
461-
// Mock the client.bulk call
470+
471+
BulkResponse bulkResponse = getSuccessBulkResponse();
462472
doAnswer(invocation -> {
463473
ActionListener<BulkResponse> listener = invocation.getArgument(1);
464-
BulkResponse bulkResponse = mock(BulkResponse.class);
465-
when(bulkResponse.hasFailures()).thenReturn(false);
466474
listener.onResponse(bulkResponse);
467475
return null;
468-
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
476+
}).when(client).bulk(any(), any());
469477

470478
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
471479
transportUndeployModelsAction.doExecute(task, request, actionListener);
480+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
481+
472482
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
473483
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
474484
}
@@ -534,4 +544,17 @@ public void testDoExecute_modelIds_moreThan1() {
534544
MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null);
535545
transportUndeployModelsAction.doExecute(task, request, actionListener);
536546
}
547+
548+
private BulkResponse getSuccessBulkResponse() {
549+
return new BulkResponse(
550+
new BulkItemResponse[]{
551+
new BulkItemResponse(
552+
1,
553+
DocWriteRequest.OpType.UPDATE,
554+
new UpdateResponse(new ShardId(ML_MODEL_INDEX, "modelId123", 0), "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED)
555+
)
556+
},
557+
100L
558+
);
559+
}
537560
}

0 commit comments

Comments
 (0)