Skip to content

Commit 190b2df

Browse files
authored
Refactor client.bulk to sdkClient.bulkDataObjectAsync when undeploying models with empty model cache edgecase (opensearch-project#4075)
* refactors undeploy models client with sdkClient bulk op Signed-off-by: Brian Flores <[email protected]> * apply spotless Signed-off-by: Brian Flores <[email protected]> --------- Signed-off-by: Brian Flores <[email protected]>
1 parent 7ade595 commit 190b2df

File tree

2 files changed

+90
-29
lines changed

2 files changed

+90
-29
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@
1010

1111
import java.time.Instant;
1212
import java.util.Arrays;
13+
import java.util.HashMap;
1314
import java.util.List;
1415
import java.util.Map;
1516
import java.util.stream.Collectors;
1617

1718
import org.opensearch.ExceptionsHelper;
1819
import org.opensearch.OpenSearchStatusException;
1920
import org.opensearch.action.ActionRequest;
20-
import org.opensearch.action.bulk.BulkRequest;
21+
import org.opensearch.action.bulk.BulkItemResponse;
2122
import org.opensearch.action.bulk.BulkResponse;
2223
import org.opensearch.action.search.SearchRequest;
2324
import org.opensearch.action.search.SearchResponse;
2425
import org.opensearch.action.support.ActionFilters;
2526
import org.opensearch.action.support.HandledTransportAction;
2627
import org.opensearch.action.support.WriteRequest;
27-
import org.opensearch.action.update.UpdateRequest;
2828
import org.opensearch.cluster.service.ClusterService;
2929
import org.opensearch.common.inject.Inject;
3030
import org.opensearch.common.settings.Settings;
@@ -55,8 +55,10 @@
5555
import org.opensearch.ml.task.MLTaskManager;
5656
import org.opensearch.ml.utils.RestActionUtils;
5757
import org.opensearch.ml.utils.TenantAwareHelper;
58+
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
5859
import org.opensearch.remote.metadata.client.SdkClient;
5960
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
61+
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
6062
import org.opensearch.remote.metadata.common.SdkClientUtils;
6163
import org.opensearch.search.SearchHit;
6264
import org.opensearch.search.builder.SearchSourceBuilder;
@@ -66,7 +68,6 @@
6668
import org.opensearch.transport.client.Client;
6769

6870
import com.google.common.annotations.VisibleForTesting;
69-
import com.google.common.collect.ImmutableMap;
7071

7172
import lombok.extern.log4j.Log4j2;
7273

@@ -217,8 +218,8 @@ private void undeployModels(
217218
return modelCacheMissForModelIds;
218219
});
219220
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
220-
log.warn("No node found running the model(s): {}", Arrays.toString(modelIds));
221-
bulkSetModelIndexToUndeploy(modelIds, listener, response);
221+
log.warn("No nodes service these models, performing manual `UNDEPLOY` write to model index");
222+
bulkSetModelIndexToUndeploy(modelIds, tenantId, listener, response);
222223
return;
223224
}
224225
log.info("Successfully undeployed model(s) from nodes: {}", Arrays.toString(modelIds));
@@ -228,34 +229,39 @@ private void undeployModels(
228229

229230
private void bulkSetModelIndexToUndeploy(
230231
String[] modelIds,
232+
String tenantId,
231233
ActionListener<MLUndeployModelsResponse> listener,
232-
MLUndeployModelNodesResponse response
234+
MLUndeployModelNodesResponse mlUndeployModelNodesResponse
233235
) {
234-
BulkRequest bulkUpdateRequest = new BulkRequest();
236+
BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build();
237+
235238
for (String modelId : modelIds) {
236-
UpdateRequest updateRequest = new UpdateRequest();
237239

238-
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
239-
builder.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
240+
Map<String, Object> updateDocument = new HashMap<>();
240241

241-
builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
242-
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
242+
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED.name());
243+
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, List.of());
244+
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
245+
updateDocument.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
246+
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
243247

244-
builder.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
245-
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
246-
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
247-
bulkUpdateRequest.add(updateRequest);
248+
UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
249+
.builder()
250+
.id(modelId)
251+
.tenantId(tenantId)
252+
.dataObject(updateDocument)
253+
.build();
254+
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
248255
}
249256

250-
bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
251257
log.info("No nodes running these models: {}", Arrays.toString(modelIds));
252258

253259
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
254260
ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener
255261
.runBefore(listener, () -> threadContext.restore());
262+
256263
ActionListener<BulkResponse> bulkResponseListener = ActionListener.wrap(br -> {
257-
log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds));
258-
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(response));
264+
listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(mlUndeployModelNodesResponse));
259265
}, e -> {
260266
String modelsNotFoundMessage = String
261267
.format("Failed to set the following modelId(s) to UNDEPLOY in index: %s", Arrays.toString(modelIds));
@@ -268,7 +274,40 @@ private void bulkSetModelIndexToUndeploy(
268274
listenerWithContextRestoration.onFailure(exception);
269275
});
270276

271-
client.bulk(bulkUpdateRequest, bulkResponseListener);
277+
sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((response, exception) -> {
278+
if (exception != null) {
279+
Exception cause = SdkClientUtils.unwrapAndConvertToException(exception, OpenSearchStatusException.class);
280+
bulkResponseListener.onFailure(cause);
281+
return;
282+
}
283+
284+
try {
285+
BulkResponse bulkResponse = BulkResponse.fromXContent(response.parser());
286+
log
287+
.info(
288+
"Executed {} bulk operations with {} failures, Took: {}",
289+
bulkResponse.getItems().length,
290+
bulkResponse.hasFailures()
291+
? Arrays.stream(bulkResponse.getItems()).filter(BulkItemResponse::isFailed).count()
292+
: 0,
293+
bulkResponse.getTook()
294+
);
295+
List<String> unemployedModelIds = Arrays
296+
.stream(bulkResponse.getItems())
297+
.filter(bulkItemResponse -> !bulkItemResponse.isFailed())
298+
.map(BulkItemResponse::getId)
299+
.collect(Collectors.toList());
300+
log
301+
.debug(
302+
"Successfully set the following modelId(s) to UNDEPLOY in index: {}",
303+
Arrays.toString(unemployedModelIds.toArray())
304+
);
305+
306+
bulkResponseListener.onResponse(bulkResponse);
307+
} catch (Exception e) {
308+
bulkResponseListener.onFailure(e);
309+
}
310+
});
272311
} catch (Exception e) {
273312
log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
274313
listener.onFailure(e);

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.io.IOException;
2323
import java.util.ArrayList;
24+
import java.util.Collections;
2425
import java.util.HashMap;
2526
import java.util.List;
2627
import java.util.Map;
@@ -30,19 +31,25 @@
3031
import org.junit.rules.ExpectedException;
3132
import org.mockito.ArgumentCaptor;
3233
import org.mockito.Mock;
34+
import org.mockito.Mockito;
3335
import org.mockito.MockitoAnnotations;
36+
import org.opensearch.action.DocWriteRequest;
37+
import org.opensearch.action.DocWriteResponse;
3438
import org.opensearch.action.FailedNodeException;
39+
import org.opensearch.action.bulk.BulkItemResponse;
3540
import org.opensearch.action.bulk.BulkRequest;
3641
import org.opensearch.action.bulk.BulkResponse;
3742
import org.opensearch.action.support.ActionFilters;
3843
import org.opensearch.action.update.UpdateRequest;
44+
import org.opensearch.action.update.UpdateResponse;
3945
import org.opensearch.cluster.ClusterName;
4046
import org.opensearch.cluster.service.ClusterService;
4147
import org.opensearch.common.settings.Settings;
4248
import org.opensearch.common.util.concurrent.ThreadContext;
4349
import org.opensearch.commons.ConfigConstants;
4450
import org.opensearch.commons.authuser.User;
4551
import org.opensearch.core.action.ActionListener;
52+
import org.opensearch.core.index.shard.ShardId;
4653
import org.opensearch.core.xcontent.NamedXContentRegistry;
4754
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
4855
import org.opensearch.ml.common.FunctionName;
@@ -59,6 +66,7 @@
5966
import org.opensearch.ml.task.MLTaskDispatcher;
6067
import org.opensearch.ml.task.MLTaskManager;
6168
import org.opensearch.remote.metadata.client.SdkClient;
69+
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
6270
import org.opensearch.tasks.Task;
6371
import org.opensearch.test.OpenSearchTestCase;
6472
import org.opensearch.threadpool.ThreadPool;
@@ -133,6 +141,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase {
133141
public void setup() throws IOException {
134142
MockitoAnnotations.openMocks(this);
135143
Settings settings = Settings.builder().build();
144+
sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()));
145+
136146
transportUndeployModelsAction = spy(
137147
new TransportUndeployModelsAction(
138148
transportService,
@@ -215,11 +225,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
215225

216226
ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
217227

228+
BulkResponse bulkResponse = getSuccessBulkResponse();
218229
// mock the bulk response that can be captured for inspecting the contents of the write to index
219230
doAnswer(invocation -> {
220231
ActionListener<BulkResponse> listener = invocation.getArgument(1);
221-
BulkResponse bulkResponse = mock(BulkResponse.class);
222-
when(bulkResponse.hasFailures()).thenReturn(false);
223232
listener.onResponse(bulkResponse);
224233
return null;
225234
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
@@ -331,11 +340,10 @@ public void testHiddenModelSuccess() {
331340
return null;
332341
}).when(client).execute(any(), any(), isA(ActionListener.class));
333342

343+
BulkResponse bulkResponse = getSuccessBulkResponse();
334344
// Mock the client.bulk call
335345
doAnswer(invocation -> {
336346
ActionListener<BulkResponse> listener = invocation.getArgument(1);
337-
BulkResponse bulkResponse = mock(BulkResponse.class);
338-
when(bulkResponse.hasFailures()).thenReturn(false);
339347
listener.onResponse(bulkResponse);
340348
return null;
341349
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
@@ -390,9 +398,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
390398
return null;
391399
}).when(client).execute(any(), any(), isA(ActionListener.class));
392400

401+
BulkResponse bulkResponse = getSuccessBulkResponse();
393402
doAnswer(invocation -> {
394403
ActionListener<BulkResponse> listener = invocation.getArgument(1);
395-
listener.onResponse(mock(BulkResponse.class));
404+
listener.onResponse(bulkResponse);
396405
return null;
397406
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
398407

@@ -456,17 +465,18 @@ public void testDoExecute() {
456465
listener.onResponse(response);
457466
return null;
458467
}).when(client).execute(any(), any(), isA(ActionListener.class));
459-
// Mock the client.bulk call
468+
469+
BulkResponse bulkResponse = getSuccessBulkResponse();
460470
doAnswer(invocation -> {
461471
ActionListener<BulkResponse> listener = invocation.getArgument(1);
462-
BulkResponse bulkResponse = mock(BulkResponse.class);
463-
when(bulkResponse.hasFailures()).thenReturn(false);
464472
listener.onResponse(bulkResponse);
465473
return null;
466-
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
474+
}).when(client).bulk(any(), any());
467475

468476
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
469477
transportUndeployModelsAction.doExecute(task, request, actionListener);
478+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
479+
470480
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
471481
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
472482
}
@@ -532,4 +542,16 @@ public void testDoExecute_modelIds_moreThan1() {
532542
MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null);
533543
transportUndeployModelsAction.doExecute(task, request, actionListener);
534544
}
545+
546+
private BulkResponse getSuccessBulkResponse() {
547+
return new BulkResponse(
548+
new BulkItemResponse[] {
549+
new BulkItemResponse(
550+
1,
551+
DocWriteRequest.OpType.UPDATE,
552+
new UpdateResponse(new ShardId(ML_MODEL_INDEX, "modelId123", 0), "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED)
553+
) },
554+
100L
555+
);
556+
}
535557
}

0 commit comments

Comments
 (0)