Skip to content

Commit 37f3f97

Browse files
committed
add UTs for undeploy stale model index fix
Added UTs for the 2 scenarios 1. Check that the bulk operation occured when no nodes are returned from the Undeploy response is , 2. Check that the bulk operation did not occur when there are nodes that have found the model within their cache. Signed-off-by: Brian Flores <[email protected]>
1 parent 08c250a commit 37f3f97

File tree

2 files changed

+153
-1
lines changed

2 files changed

+153
-1
lines changed

plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ private void bulkSetModelIndexToUndeploy(
217217
}
218218

219219
bulkUpdateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
220-
log.info("No nodes service: {}", modelIds.toString());
220+
log.info("No nodes service: {}", Arrays.toString(modelIds));
221221

222222
client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> {
223223
log.debug("Successfully set modelIds to UNDEPLOY in index");

plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
import static org.mockito.Mockito.doReturn;
1414
import static org.mockito.Mockito.doThrow;
1515
import static org.mockito.Mockito.mock;
16+
import static org.mockito.Mockito.never;
1617
import static org.mockito.Mockito.spy;
1718
import static org.mockito.Mockito.verify;
1819
import static org.mockito.Mockito.when;
20+
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
1921
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;
2022

2123
import java.io.IOException;
2224
import java.util.ArrayList;
2325
import java.util.List;
26+
import java.util.Map;
2427

2528
import org.junit.Before;
2629
import org.junit.Rule;
@@ -29,7 +32,10 @@
2932
import org.mockito.Mock;
3033
import org.mockito.MockitoAnnotations;
3134
import org.opensearch.action.FailedNodeException;
35+
import org.opensearch.action.bulk.BulkRequest;
36+
import org.opensearch.action.bulk.BulkResponse;
3237
import org.opensearch.action.support.ActionFilters;
38+
import org.opensearch.action.update.UpdateRequest;
3339
import org.opensearch.client.Client;
3440
import org.opensearch.cluster.ClusterName;
3541
import org.opensearch.cluster.service.ClusterService;
@@ -42,6 +48,7 @@
4248
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
4349
import org.opensearch.ml.common.FunctionName;
4450
import org.opensearch.ml.common.MLModel;
51+
import org.opensearch.ml.common.model.MLModelState;
4552
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
4653
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
4754
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
@@ -172,6 +179,129 @@ public void setup() throws IOException {
172179
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));
173180
}
174181

182+
public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() {
183+
String modelId = "someModelId";
184+
MLModel mlModel = MLModel
185+
.builder()
186+
.user(User.parse(USER_STRING))
187+
.modelGroupId("111")
188+
.version("111")
189+
.name("Test Model")
190+
.modelId(modelId)
191+
.algorithm(FunctionName.BATCH_RCF)
192+
.content("content")
193+
.totalChunks(2)
194+
.isHidden(true)
195+
.build();
196+
197+
doAnswer(invocation -> {
198+
ActionListener<MLModel> listener = invocation.getArgument(3);
199+
listener.onResponse(mlModel);
200+
return null;
201+
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));
202+
203+
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
204+
205+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
206+
List<FailedNodeException> failuresList = new ArrayList<>();
207+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
208+
209+
// Send back a response with no nodes associated to the model. Thus, will write back to the model index that its UNDEPLOYED
210+
doAnswer(invocation -> {
211+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
212+
listener.onResponse(nodesResponse);
213+
return null;
214+
}).when(client).execute(any(), any(), isA(ActionListener.class));
215+
216+
ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
217+
218+
// mock the bulk response that can be captured for inspecting the contents of the write to index
219+
doAnswer(invocation -> {
220+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
221+
BulkResponse bulkResponse = mock(BulkResponse.class);
222+
when(bulkResponse.hasFailures()).thenReturn(false);
223+
listener.onResponse(bulkResponse);
224+
return null;
225+
}).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class));
226+
227+
String[] modelIds = new String[] { modelId };
228+
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
229+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
230+
231+
transportUndeployModelsAction.doExecute(task, request, actionListener);
232+
233+
BulkRequest capturedBulkRequest = bulkRequestCaptor.getValue();
234+
assertEquals(1, capturedBulkRequest.numberOfActions());
235+
UpdateRequest updateRequest = (UpdateRequest) capturedBulkRequest.requests().get(0);
236+
237+
@SuppressWarnings("unchecked")
238+
Map<String, Object> updateDoc = updateRequest.doc().sourceAsMap();
239+
String modelIdFromBulkRequest = updateRequest.id();
240+
String indexNameFromBulkRequest = updateRequest.index();
241+
242+
assertEquals("Check that the write happened at the model index", ML_MODEL_INDEX, indexNameFromBulkRequest);
243+
assertEquals("Check that the result bulk write hit this specific modelId", modelId, modelIdFromBulkRequest);
244+
245+
assertEquals(MLModelState.UNDEPLOYED.name(), updateDoc.get(MLModel.MODEL_STATE_FIELD));
246+
assertEquals(0, updateDoc.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD));
247+
assertEquals(0, updateDoc.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD));
248+
assertEquals(List.of(), updateDoc.get(MLModel.PLANNING_WORKER_NODES_FIELD));
249+
assertTrue(updateDoc.containsKey(MLModel.LAST_UPDATED_TIME_FIELD));
250+
251+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
252+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
253+
}
254+
255+
public void testDoExecute_noBulkRequestFired_WhenSomeNodesServiceModel() {
256+
String modelId = "someModelId";
257+
MLModel mlModel = MLModel
258+
.builder()
259+
.user(User.parse(USER_STRING))
260+
.modelGroupId("111")
261+
.version("111")
262+
.name("Test Model")
263+
.modelId(modelId)
264+
.algorithm(FunctionName.BATCH_RCF)
265+
.content("content")
266+
.totalChunks(2)
267+
.isHidden(true)
268+
.build();
269+
270+
doAnswer(invocation -> {
271+
ActionListener<MLModel> listener = invocation.getArgument(3);
272+
listener.onResponse(mlModel);
273+
return null;
274+
}).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class));
275+
276+
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
277+
278+
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
279+
responseList.add(mock(MLUndeployModelNodeResponse.class));
280+
responseList.add(mock(MLUndeployModelNodeResponse.class));
281+
List<FailedNodeException> failuresList = new ArrayList<>();
282+
failuresList.add(mock(FailedNodeException.class));
283+
failuresList.add(mock(FailedNodeException.class));
284+
285+
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
286+
287+
// Send back a response with nodes associated to the model
288+
doAnswer(invocation -> {
289+
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
290+
listener.onResponse(nodesResponse);
291+
return null;
292+
}).when(client).execute(any(), any(), isA(ActionListener.class));
293+
294+
String[] modelIds = new String[] { modelId };
295+
String[] nodeIds = new String[] { "test_node_id1", "test_node_id2" };
296+
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds);
297+
298+
transportUndeployModelsAction.doExecute(task, request, actionListener);
299+
300+
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
301+
// Check that no bulk write occurred Since there were nodes servicing the model
302+
verify(client, never()).bulk(any(BulkRequest.class), any(ActionListener.class));
303+
}
304+
175305
public void testHiddenModelSuccess() {
176306
MLModel mlModel = MLModel
177307
.builder()
@@ -194,16 +324,28 @@ public void testHiddenModelSuccess() {
194324
List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();
195325
List<FailedNodeException> failuresList = new ArrayList<>();
196326
MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);
327+
197328
doAnswer(invocation -> {
198329
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
199330
listener.onResponse(response);
200331
return null;
201332
}).when(client).execute(any(), any(), isA(ActionListener.class));
202333

334+
// Mock the client.bulk call
335+
doAnswer(invocation -> {
336+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
337+
BulkResponse bulkResponse = mock(BulkResponse.class);
338+
when(bulkResponse.hasFailures()).thenReturn(false);
339+
listener.onResponse(bulkResponse);
340+
return null;
341+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
342+
203343
doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);
204344
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
205345
transportUndeployModelsAction.doExecute(task, request, actionListener);
346+
206347
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
348+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
207349
}
208350

209351
public void testHiddenModelPermissionError() {
@@ -257,9 +399,19 @@ public void testDoExecute() {
257399
listener.onResponse(response);
258400
return null;
259401
}).when(client).execute(any(), any(), isA(ActionListener.class));
402+
// Mock the client.bulk call
403+
doAnswer(invocation -> {
404+
ActionListener<BulkResponse> listener = invocation.getArgument(1);
405+
BulkResponse bulkResponse = mock(BulkResponse.class);
406+
when(bulkResponse.hasFailures()).thenReturn(false);
407+
listener.onResponse(bulkResponse);
408+
return null;
409+
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
410+
260411
MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);
261412
transportUndeployModelsAction.doExecute(task, request, actionListener);
262413
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
414+
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
263415
}
264416

265417
public void testDoExecute_modelAccessControl_notEnabled() {

0 commit comments

Comments
 (0)