|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.action.undeploy; |
7 | 7 |
|
8 | | -import static org.mockito.ArgumentMatchers.any; |
9 | | -import static org.mockito.ArgumentMatchers.isA; |
10 | | -import static org.mockito.Mockito.doAnswer; |
11 | | -import static org.mockito.Mockito.doReturn; |
12 | | -import static org.mockito.Mockito.doThrow; |
13 | | -import static org.mockito.Mockito.mock; |
14 | | -import static org.mockito.Mockito.never; |
15 | | -import static org.mockito.Mockito.spy; |
16 | | -import static org.mockito.Mockito.verify; |
17 | | -import static org.mockito.Mockito.when; |
18 | | -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; |
19 | | -import static org.opensearch.ml.common.CommonValue.NOT_FOUND; |
20 | | -import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; |
21 | | - |
22 | | -import java.io.IOException; |
23 | | -import java.util.ArrayList; |
24 | | -import java.util.HashMap; |
25 | | -import java.util.List; |
26 | | -import java.util.Map; |
27 | | - |
28 | 8 | import org.junit.Before; |
29 | 9 | import org.junit.Rule; |
30 | 10 | import org.junit.rules.ExpectedException; |
31 | 11 | import org.mockito.ArgumentCaptor; |
32 | 12 | import org.mockito.Mock; |
| 13 | +import org.mockito.Mockito; |
33 | 14 | import org.mockito.MockitoAnnotations; |
| 15 | +import org.opensearch.action.DocWriteRequest; |
| 16 | +import org.opensearch.action.DocWriteResponse; |
34 | 17 | import org.opensearch.action.FailedNodeException; |
| 18 | +import org.opensearch.action.bulk.BulkItemResponse; |
35 | 19 | import org.opensearch.action.bulk.BulkRequest; |
36 | 20 | import org.opensearch.action.bulk.BulkResponse; |
37 | 21 | import org.opensearch.action.support.ActionFilters; |
38 | 22 | import org.opensearch.action.update.UpdateRequest; |
| 23 | +import org.opensearch.action.update.UpdateResponse; |
| 24 | +import org.opensearch.client.Client; |
39 | 25 | import org.opensearch.cluster.ClusterName; |
40 | 26 | import org.opensearch.cluster.service.ClusterService; |
41 | 27 | import org.opensearch.common.settings.Settings; |
42 | 28 | import org.opensearch.common.util.concurrent.ThreadContext; |
43 | 29 | import org.opensearch.commons.ConfigConstants; |
44 | 30 | import org.opensearch.commons.authuser.User; |
45 | 31 | import org.opensearch.core.action.ActionListener; |
| 32 | +import org.opensearch.core.index.shard.ShardId; |
46 | 33 | import org.opensearch.core.xcontent.NamedXContentRegistry; |
47 | 34 | import org.opensearch.ml.cluster.DiscoveryNodeHelper; |
48 | 35 | import org.opensearch.ml.common.FunctionName; |
|
59 | 46 | import org.opensearch.ml.task.MLTaskDispatcher; |
60 | 47 | import org.opensearch.ml.task.MLTaskManager; |
61 | 48 | import org.opensearch.remote.metadata.client.SdkClient; |
| 49 | +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; |
62 | 50 | import org.opensearch.tasks.Task; |
63 | 51 | import org.opensearch.test.OpenSearchTestCase; |
64 | 52 | import org.opensearch.threadpool.ThreadPool; |
65 | 53 | import org.opensearch.transport.TransportService; |
66 | | -import org.opensearch.transport.client.Client; |
| 54 | + |
| 55 | +import java.io.IOException; |
| 56 | +import java.util.ArrayList; |
| 57 | +import java.util.Collections; |
| 58 | +import java.util.HashMap; |
| 59 | +import java.util.List; |
| 60 | +import java.util.Map; |
| 61 | + |
| 62 | +import static org.mockito.ArgumentMatchers.any; |
| 63 | +import static org.mockito.ArgumentMatchers.isA; |
| 64 | +import static org.mockito.Mockito.doAnswer; |
| 65 | +import static org.mockito.Mockito.doReturn; |
| 66 | +import static org.mockito.Mockito.doThrow; |
| 67 | +import static org.mockito.Mockito.mock; |
| 68 | +import static org.mockito.Mockito.never; |
| 69 | +import static org.mockito.Mockito.spy; |
| 70 | +import static org.mockito.Mockito.verify; |
| 71 | +import static org.mockito.Mockito.when; |
| 72 | +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; |
| 73 | +import static org.opensearch.ml.common.CommonValue.NOT_FOUND; |
| 74 | +import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; |
67 | 75 |
|
68 | 76 | public class TransportUndeployModelsActionTests extends OpenSearchTestCase { |
69 | 77 |
|
@@ -133,6 +141,8 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { |
133 | 141 | public void setup() throws IOException { |
134 | 142 | MockitoAnnotations.openMocks(this); |
135 | 143 | Settings settings = Settings.builder().build(); |
| 144 | + sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap())); |
| 145 | + |
136 | 146 | transportUndeployModelsAction = spy( |
137 | 147 | new TransportUndeployModelsAction( |
138 | 148 | transportService, |
@@ -215,11 +225,10 @@ public void testDoExecute_undeployModelIndex_WhenNoNodesServiceModel() { |
215 | 225 |
|
216 | 226 | ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); |
217 | 227 |
|
| 228 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
218 | 229 | // mock the bulk response that can be captured for inspecting the contents of the write to index |
219 | 230 | doAnswer(invocation -> { |
220 | 231 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
221 | | - BulkResponse bulkResponse = mock(BulkResponse.class); |
222 | | - when(bulkResponse.hasFailures()).thenReturn(false); |
223 | 232 | listener.onResponse(bulkResponse); |
224 | 233 | return null; |
225 | 234 | }).when(client).bulk(bulkRequestCaptor.capture(), any(ActionListener.class)); |
@@ -331,11 +340,10 @@ public void testHiddenModelSuccess() { |
331 | 340 | return null; |
332 | 341 | }).when(client).execute(any(), any(), isA(ActionListener.class)); |
333 | 342 |
|
| 343 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
334 | 344 | // Mock the client.bulk call |
335 | 345 | doAnswer(invocation -> { |
336 | 346 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
337 | | - BulkResponse bulkResponse = mock(BulkResponse.class); |
338 | | - when(bulkResponse.hasFailures()).thenReturn(false); |
339 | 347 | listener.onResponse(bulkResponse); |
340 | 348 | return null; |
341 | 349 | }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
@@ -390,9 +398,10 @@ public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() { |
390 | 398 | return null; |
391 | 399 | }).when(client).execute(any(), any(), isA(ActionListener.class)); |
392 | 400 |
|
| 401 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
393 | 402 | doAnswer(invocation -> { |
394 | 403 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
395 | | - listener.onResponse(mock(BulkResponse.class)); |
| 404 | + listener.onResponse(bulkResponse); |
396 | 405 | return null; |
397 | 406 | }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
398 | 407 |
|
@@ -456,17 +465,18 @@ public void testDoExecute() { |
456 | 465 | listener.onResponse(response); |
457 | 466 | return null; |
458 | 467 | }).when(client).execute(any(), any(), isA(ActionListener.class)); |
459 | | - // Mock the client.bulk call |
| 468 | + |
| 469 | + BulkResponse bulkResponse = getSuccessBulkResponse(); |
460 | 470 | doAnswer(invocation -> { |
461 | 471 | ActionListener<BulkResponse> listener = invocation.getArgument(1); |
462 | | - BulkResponse bulkResponse = mock(BulkResponse.class); |
463 | | - when(bulkResponse.hasFailures()).thenReturn(false); |
464 | 472 | listener.onResponse(bulkResponse); |
465 | 473 | return null; |
466 | | - }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
| 474 | + }).when(client).bulk(any(), any()); |
467 | 475 |
|
468 | 476 | MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); |
469 | 477 | transportUndeployModelsAction.doExecute(task, request, actionListener); |
| 478 | + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); |
| 479 | + |
470 | 480 | verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); |
471 | 481 | verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); |
472 | 482 | } |
@@ -532,4 +542,17 @@ public void testDoExecute_modelIds_moreThan1() { |
532 | 542 | MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null); |
533 | 543 | transportUndeployModelsAction.doExecute(task, request, actionListener); |
534 | 544 | } |
| 545 | + |
| 546 | + private BulkResponse getSuccessBulkResponse() { |
| 547 | + return new BulkResponse( |
| 548 | + new BulkItemResponse[]{ |
| 549 | + new BulkItemResponse( |
| 550 | + 1, |
| 551 | + DocWriteRequest.OpType.UPDATE, |
| 552 | + new UpdateResponse(new ShardId(ML_MODEL_INDEX, "modelId123", 0), "id1", 1, 1, 1, DocWriteResponse.Result.UPDATED) |
| 553 | + ) |
| 554 | + }, |
| 555 | + 100L |
| 556 | + ); |
| 557 | + } |
535 | 558 | } |
0 commit comments