Skip to content

Commit 6d3b398

Browse files
committed
add UTs for undeploy stale model index fix
Added UTs for the 2 scenarios 1. Check that the bulk operation occured when no nodes are returned from the Undeploy response is , 2. Check that the bulk operation did not occur when there are nodes that have found the model within their cache. Signed-off-by: Brian Flores <[email protected]>
1 parent 9104cb8 commit 6d3b398

File tree

2 files changed

+153
-1
lines changed

2 files changed

+153
-1
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ private void bulkSetModelIndexToUndeploy(
194194
}
195195

196196
bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
197-
log.info("No nodes service: {}", modelIds.toString());
197+
log.info("No nodes service: {}", Arrays.toString(modelIds));
198198

199199
client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> {
200200
log.debug("Successfully set modelIds to UNDEPLOY in index");

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
import static org.mockito.Mockito.doReturn;
1414
import static org.mockito.Mockito.doThrow;
1515
import static org.mockito.Mockito.mock;
16+
import static org.mockito.Mockito.never;
1617
import static org.mockito.Mockito.spy;
1718
import static org.mockito.Mockito.verify;
1819
import static org.mockito.Mockito.when;
20+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
1921
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
2022

2123
import java.io.IOException;
2224
import java.util.ArrayList;
2325
import java.util.List;
26+
import java.util.Map;
2427

2528
import org.junit.Before;
2629
import org.junit.Rule;
@@ -29,7 +32,10 @@
2932
import org.mockito.Mock;
3033
import org.mockito.MockitoAnnotations;
3134
import org.opensearch.action.FailedNodeException;
35+
import org.opensearch.action.bulk.BulkRequest;
36+
import org.opensearch.action.bulk.BulkResponse;
3237
import org.opensearch.action.support.ActionFilters;
38+
import org.opensearch.action.update.UpdateRequest;
3339
import org.opensearch.client.Client;
3440
import org.opensearch.cluster.ClusterName;
3541
import org.opensearch.cluster.service.ClusterService;
@@ -42,6 +48,7 @@
4248
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
4349
import org.opensearch.ml.common.FunctionName;
4450
import org.opensearch.ml.common.MLModel;
51+
import org.opensearch.ml.common.model.MLModelState;
4552
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
4653
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
4754
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
@@ -164,6 +171,129 @@ public void setup() throws IOException {
164171
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));
165172
}
166173

174+
public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
175+
String modelId = "someModelId";
176+
MLModel mlModel = MLModel
177+
.builder()
178+
.user(User.parse(USER_STRING))
179+
.modelGroupId("111")
180+
.version("111")
181+
.name("Test Model")
182+
.modelId(modelId)
183+
.algorithm(FunctionName.BATCH_RCF)
184+
.content("content")
185+
.totalChunks(2)
186+
.isHidden(true)
187+
.build();
188+
189+
doAnswer(invocation -> {
190+
ActionListener<MLModel> listener = invocation.getArgument(3);
191+
listener.onResponse(mlModel);
192+
return null;
193+
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));
194+
195+
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
196+
197+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
198+
List<FailedNodeException> failuresList = new ArrayList<>();
199+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
200+
201+
// Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
202+
doAnswer(invocation -> {
203+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
204+
listener.onResponse(nodesResponse);
205+
return null;
206+
}).when(client).execute(any(), any(), isA(ActionListener.class));
207+
208+
ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
209+
210+
// mock the bulk response that can be captured for inspecting the contents of the write to index
211+
doAnswer(invocation -> {
212+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
213+
BulkResponse bulkResponse = mock(BulkResponse.class);
214+
when(bulkResponse.hasFailures()).thenReturn(false);
215+
listener.onResponse(bulkResponse);
216+
return null;
217+
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
218+
219+
String[] modelIds = new String[] { modelId };
220+
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
221+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
222+
223+
transportUndeployModelsAction.doExecute(task, request, actionListener);
224+
225+
BulkRequest capturedBulkRequest = bulkRequestCaptor.getValue();
226+
assertEquals(1, capturedBulkRequest.numberOfActions());
227+
UpdateRequest updateRequest = (UpdateRequest) capturedBulkRequest.requests().get(0);
228+
229+
@SuppressWarnings("unchecked")
230+
Map<String, Object> updateDoc = updateRequest.doc().sourceAsMap();
231+
String modelIdFromBulkRequest = updateRequest.id();
232+
String indexNameFromBulkRequest = updateRequest.index();
233+
234+
assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, indexNameFromBulkRequest);
235+
assertEquals("Check that the result bulk write hit this specific modelId", modelId, modelIdFromBulkRequest);
236+
237+
assertEquals(MLModelState.UNDEPLOYED.name(), updateDoc.get(MLModel.MODEL_STATE_FIELD));
238+
assertEquals(0, updateDoc.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
239+
assertEquals(0, updateDoc.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
240+
assertEquals(List.of(), updateDoc.get(MLModel.PLANNING_WORKER_NODES_FIELD));
241+
assertTrue(updateDoc.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));
242+
243+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
244+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
245+
}
246+
247+
public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
248+
String modelId = "someModelId";
249+
MLModel mlModel = MLModel
250+
.builder()
251+
.user(User.parse(USER_STRING))
252+
.modelGroupId("111")
253+
.version("111")
254+
.name("Test Model")
255+
.modelId(modelId)
256+
.algorithm(FunctionName.BATCH_RCF)
257+
.content("content")
258+
.totalChunks(2)
259+
.isHidden(true)
260+
.build();
261+
262+
doAnswer(invocation -> {
263+
ActionListener<MLModel> listener = invocation.getArgument(3);
264+
listener.onResponse(mlModel);
265+
return null;
266+
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));
267+
268+
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
269+
270+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
271+
responseList.add(mock(MLUndeployModelNodeResponse.class));
272+
responseList.add(mock(MLUndeployModelNodeResponse.class));
273+
List<FailedNodeException> failuresList = new ArrayList<>();
274+
failuresList.add(mock(FailedNodeException.class));
275+
failuresList.add(mock(FailedNodeException.class));
276+
277+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
278+
279+
// Send back a response with nodes associated to the model
280+
doAnswer(invocation -> {
281+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
282+
listener.onResponse(nodesResponse);
283+
return null;
284+
}).when(client).execute(any(), any(), isA(ActionListener.class));
285+
286+
String[] modelIds = new String[] { modelId };
287+
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
288+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
289+
290+
transportUndeployModelsAction.doExecute(task, request, actionListener);
291+
292+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
293+
// Check that no bulk write occurred Since there were nodes servicing the model
294+
verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
295+
}
296+
167297
public void testHiddenModelSuccess() {
168298
MLModel mlModel = MLModel
169299
.builder()
@@ -186,16 +316,28 @@ public void testHiddenModelSuccess() {
186316
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
187317
List<FailedNodeException> failuresList = new ArrayList<>();
188318
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
319+
189320
doAnswer(invocation -> {
190321
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
191322
listener.onResponse(response);
192323
return null;
193324
}).when(client).execute(any(), any(), isA(ActionListener.class));
194325

326+
// Mock the client.bulk call
327+
doAnswer(invocation -> {
328+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
329+
BulkResponse bulkResponse = mock(BulkResponse.class);
330+
when(bulkResponse.hasFailures()).thenReturn(false);
331+
listener.onResponse(bulkResponse);
332+
return null;
333+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
334+
195335
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
196336
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
197337
transportUndeployModelsAction.doExecute(task, request, actionListener);
338+
198339
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
340+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
199341
}
200342

201343
public void testHiddenModelPermissionError() {
@@ -249,9 +391,19 @@ public void testDoExecute() {
249391
listener.onResponse(response);
250392
return null;
251393
}).when(client).execute(any(), any(), isA(ActionListener.class));
394+
// Mock the client.bulk call
395+
doAnswer(invocation -> {
396+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
397+
BulkResponse bulkResponse = mock(BulkResponse.class);
398+
when(bulkResponse.hasFailures()).thenReturn(false);
399+
listener.onResponse(bulkResponse);
400+
return null;
401+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
402+
252403
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
253404
transportUndeployModelsAction.doExecute(task, request, actionListener);
254405
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
406+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
255407
}
256408

257409
public void testDoExecute_modelAccessControl_notEnabled() {

0 commit comments

Comments
 (0)