|
7 | 7 |
|
8 | 8 | package org.opensearch.ml.action.undeploy; |
9 | 9 |
|
10 | | -import static org.mockito.ArgumentMatchers.any; |
11 | | -import static org.mockito.ArgumentMatchers.isA; |
12 | | -import static org.mockito.Mockito.doAnswer; |
13 | | -import static org.mockito.Mockito.doReturn; |
14 | | -import static org.mockito.Mockito.doThrow; |
15 | | -import static org.mockito.Mockito.mock; |
16 | | -import static org.mockito.Mockito.never; |
17 | | -import static org.mockito.Mockito.spy; |
18 | | -import static org.mockito.Mockito.verify; |
19 | | -import static org.mockito.Mockito.when; |
20 | | -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; |
21 | | -import static org.opensearch.ml.common.CommonValue.NOT_FOUND; |
22 | | -import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; |
23 | | - |
24 | | -import java.io.IOException; |
25 | | -import java.util.ArrayList; |
26 | | -import java.util.HashMap; |
27 | | -import java.util.List; |
28 | | -import java.util.Map; |
29 | | - |
30 | 10 | import org.junit.Before; |
31 | 11 | import org.junit.Rule; |
32 | 12 | import org.junit.rules.ExpectedException; |
33 | 13 | import org.mockito.ArgumentCaptor; |
34 | 14 | import org.mockito.Mock; |
| 15 | +import org.mockito.Mockito; |
35 | 16 | import org.mockito.MockitoAnnotations; |
| 17 | +import org.opensearch.action.DocWriteRequest; |
| 18 | +import org.opensearch.action.DocWriteResponse; |
36 | 19 | import org.opensearch.action.FailedNodeException; |
| 20 | +import org.opensearch.action.bulk.BulkItemResponse; |
37 | 21 | import org.opensearch.action.bulk.BulkRequest; |
38 | 22 | import org.opensearch.action.bulk.BulkResponse; |
39 | 23 | import org.opensearch.action.support.ActionFilters; |
40 | 24 | import org.opensearch.action.update.UpdateRequest; |
| 25 | +import org.opensearch.action.update.UpdateResponse; |
41 | 26 | import org.opensearch.client.Client; |
42 | 27 | import org.opensearch.cluster.ClusterName; |
43 | 28 | import org.opensearch.cluster.service.ClusterService; |
|
46 | 31 | import org.opensearch.commons.ConfigConstants; |
47 | 32 | import org.opensearch.commons.authuser.User; |
48 | 33 | import org.opensearch.core.action.ActionListener; |
| 34 | +import org.opensearch.core.index.shard.ShardId; |
49 | 35 | import org.opensearch.core.xcontent.NamedXContentRegistry; |
50 | 36 | import org.opensearch.ml.cluster.DiscoveryNodeHelper; |
51 | 37 | import org.opensearch.ml.common.FunctionName; |
|
62 | 48 | import org.opensearch.ml.task.MLTaskDispatcher; |
63 | 49 | import org.opensearch.ml.task.MLTaskManager; |
64 | 50 | import org.opensearch.remote.metadata.client.SdkClient; |
| 51 | +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; |
65 | 52 | import org.opensearch.tasks.Task; |
66 | 53 | import org.opensearch.test.OpenSearchTestCase; |
67 | 54 | import org.opensearch.threadpool.ThreadPool; |
68 | 55 | import org.opensearch.transport.TransportService; |
69 | 56 |
|
| 57 | +import java.io.IOException; |
| 58 | +import java.util.ArrayList; |
| 59 | +import java.util.Collections; |
| 60 | +import java.util.HashMap; |
| 61 | +import java.util.List; |
| 62 | +import java.util.Map; |
| 63 | + |
| 64 | +import static org.mockito.ArgumentMatchers.any; |
| 65 | +import static org.mockito.ArgumentMatchers.isA; |
| 66 | +import static org.mockito.Mockito.doAnswer; |
| 67 | +import static org.mockito.Mockito.doReturn; |
| 68 | +import static org.mockito.Mockito.doThrow; |
| 69 | +import static org.mockito.Mockito.mock; |
| 70 | +import static org.mockito.Mockito.never; |
| 71 | +import static org.mockito.Mockito.spy; |
| 72 | +import static org.mockito.Mockito.verify; |
| 73 | +import static org.mockito.Mockito.when; |
| 74 | +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; |
| 75 | +import static org.opensearch.ml.common.CommonValue.NOT_FOUND; |
| 76 | +import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; |
| 77 | + |
70 | 78 | public class TransportUndeployModelsActionTests extends OpenSearchTestCase { |
71 | 79 |
|
72 | 80 | @Mock |
@@ -135,6 +143,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { |
135 | 143 | public void setup() throws IOException { |
136 | 144 | MockitoAnnotations.openMocks(this); |
137 | 145 | Settings settings = Settings.builder().build(); |
| 146 | + sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap())); |
| 147 | + |
138 | 148 | transportUndeployModelsAction = spy( |
139 | 149 | new TransportUndeployModelsAction( |
140 | 150 | transportService, |
@@ -217,11 +227,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() { |
217 | 227 |
|
218 | 228 | ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); |
219 | 229 |
|
| 230 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
220 | 231 | // mock the bulk response that can be captured for inspecting the contents of the write to index |
221 | 232 | doAnswer(invocation -> { |
222 | 233 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
223 | | - BulkResponse bulkResponse = mock(BulkResponse.class); |
224 | | - when(bulkResponse.hasFailures()).thenReturn(false); |
225 | 234 | listener.onResponse(bulkResponse); |
226 | 235 | return null; |
227 | 236 | }).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class)); |
@@ -333,11 +342,10 @@ public void testHiddenModelSuccess() { |
333 | 342 | return null; |
334 | 343 | }).when(client).execute(any(), any(), isA(ActionListener.class)); |
335 | 344 |
|
| 345 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
336 | 346 | // Mock the client.bulk call |
337 | 347 | doAnswer(invocation -> { |
338 | 348 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
339 | | - BulkResponse bulkResponse = mock(BulkResponse.class); |
340 | | - when(bulkResponse.hasFailures()).thenReturn(false); |
341 | 349 | listener.onResponse(bulkResponse); |
342 | 350 | return null; |
343 | 351 | }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
@@ -392,9 +400,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() { |
392 | 400 | return null; |
393 | 401 | }).when(client).execute(any(), any(), isA(ActionListener.class)); |
394 | 402 |
|
| 403 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
395 | 404 | doAnswer(invocation -> { |
396 | 405 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
397 | | - listener.onResponse(mock(BulkResponse.class)); |
| 406 | + listener.onResponse(bulkResponse); |
398 | 407 | return null; |
399 | 408 | }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
400 | 409 |
|
@@ -458,17 +467,18 @@ public void testDoExecute() { |
458 | 467 | listener.onResponse(response); |
459 | 468 | return null; |
460 | 469 | }).when(client).execute(any(), any(), isA(ActionListener.class)); |
461 | | - // Mock the client.bulk call |
| 470 | + |
| 471 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
462 | 472 | doAnswer(invocation -> { |
463 | 473 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
464 | | - BulkResponse bulkResponse = mock(BulkResponse.class); |
465 | | - when(bulkResponse.hasFailures()).thenReturn(false); |
466 | 474 | listener.onResponse(bulkResponse); |
467 | 475 | return null; |
468 | | - }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
| 476 | + }).when(client).bulk(any(), any()); |
469 | 477 |
|
470 | 478 | MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); |
471 | 479 | transportUndeployModelsAction.doExecute(task, request, actionListener); |
| 480 | + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); |
| 481 | + |
472 | 482 | verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); |
473 | 483 | verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
474 | 484 | } |
@@ -534,4 +544,17 @@ public void testDoExecute_modelIds_moreThan1() { |
534 | 544 | MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null); |
535 | 545 | transportUndeployModelsAction.doExecute(task, request, actionListener); |
536 | 546 | } |
| 547 | + |
| 548 | + private BulkResponse getSuccessBulkResponse() { |
| 549 | + return new BulkResponse( |
| 550 | + new BulkItemResponse[]{ |
| 551 | + new BulkItemResponse( |
| 552 | + 1, |
| 553 | + DocWriteRequest.OpType.UPDATE, |
| 554 | + new UpdateResponse(new ShardId(ML_MODEL_INDEX, "modelId123", 0), "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED) |
| 555 | + ) |
| 556 | + }, |
| 557 | + 100L |
| 558 | + ); |
| 559 | + } |
537 | 560 | } |
0 commit comments