|
13 | 13 | import org.elasticsearch.ElasticsearchStatusException; |
14 | 14 | import org.elasticsearch.action.ActionListener; |
15 | 15 | import org.elasticsearch.action.LatchedActionListener; |
| 16 | +import org.elasticsearch.action.support.ActionTestUtils; |
16 | 17 | import org.elasticsearch.action.support.PlainActionFuture; |
17 | 18 | import org.elasticsearch.client.internal.Client; |
18 | 19 | import org.elasticsearch.cluster.service.ClusterService; |
|
46 | 47 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; |
47 | 48 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; |
48 | 49 | import org.elasticsearch.xpack.core.ml.MachineLearningField; |
| 50 | +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; |
49 | 51 | import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; |
50 | 52 | import org.elasticsearch.xpack.core.ml.action.InferModelAction; |
51 | 53 | import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; |
52 | 54 | import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; |
53 | 55 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; |
54 | 56 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; |
| 57 | +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; |
55 | 58 | import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; |
56 | 59 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; |
57 | 60 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; |
|
67 | 70 | import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; |
68 | 71 | import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; |
69 | 72 | import org.elasticsearch.xpack.inference.services.ServiceFields; |
| 73 | +import org.hamcrest.Matchers; |
70 | 74 | import org.junit.After; |
71 | 75 | import org.junit.Before; |
72 | 76 | import org.mockito.ArgumentCaptor; |
73 | 77 | import org.mockito.Mockito; |
74 | 78 |
|
| 79 | +import java.io.IOException; |
75 | 80 | import java.util.ArrayList; |
76 | 81 | import java.util.Arrays; |
77 | 82 | import java.util.EnumSet; |
|
81 | 86 | import java.util.Optional; |
82 | 87 | import java.util.Set; |
83 | 88 | import java.util.concurrent.CountDownLatch; |
| 89 | +import java.util.concurrent.TimeUnit; |
84 | 90 | import java.util.concurrent.atomic.AtomicBoolean; |
85 | 91 | import java.util.concurrent.atomic.AtomicInteger; |
86 | 92 | import java.util.concurrent.atomic.AtomicReference; |
87 | 93 |
|
88 | 94 | import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; |
89 | 95 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; |
| 96 | +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; |
90 | 97 | import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; |
91 | 98 | import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; |
92 | 99 | import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; |
|
101 | 108 | import static org.mockito.ArgumentMatchers.same; |
102 | 109 | import static org.mockito.Mockito.doAnswer; |
103 | 110 | import static org.mockito.Mockito.mock; |
| 111 | +import static org.mockito.Mockito.verify; |
104 | 112 | import static org.mockito.Mockito.when; |
105 | 113 |
|
106 | 114 | public class ElasticsearchInternalServiceTests extends ESTestCase { |
@@ -1632,6 +1640,67 @@ public void testGetConfiguration() throws Exception { |
1632 | 1640 | } |
1633 | 1641 | } |
1634 | 1642 |
|
| 1643 | + public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException { |
| 1644 | + var cs = mock(ClusterService.class); |
| 1645 | + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
| 1646 | + when(cs.getClusterSettings()).thenReturn(cSettings); |
| 1647 | + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( |
| 1648 | + mock(), |
| 1649 | + threadPool, |
| 1650 | + cs, |
| 1651 | + Settings.builder().put("xpack.ml.enabled", false).build() |
| 1652 | + ); |
| 1653 | + try (var service = new ElasticsearchInternalService(context)) { |
| 1654 | + var models = List.of(mock(Model.class)); |
| 1655 | + var latch = new CountDownLatch(1); |
| 1656 | + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> { |
| 1657 | + latch.countDown(); |
| 1658 | + assertThat(r, Matchers.sameInstance(models)); |
| 1659 | + })); |
| 1660 | + assertTrue(latch.await(30, TimeUnit.SECONDS)); |
| 1661 | + } |
| 1662 | + } |
| 1663 | + |
| 1664 | + public void testUpdateWithMlEnabled() throws IOException, InterruptedException { |
| 1665 | + var deploymentId = "deploymentId"; |
| 1666 | + var model = mock(ElasticsearchInternalModel.class); |
| 1667 | + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); |
| 1668 | + |
| 1669 | + AssignmentStats stats = mock(); |
| 1670 | + when(stats.getDeploymentId()).thenReturn(deploymentId); |
| 1671 | + when(stats.getNumberOfAllocations()).thenReturn(3); |
| 1672 | + |
| 1673 | + var client = mock(Client.class); |
| 1674 | + doAnswer(ans -> { |
| 1675 | + QueryPage<AssignmentStats> queryPage = new QueryPage<>(List.of(stats), 1, RESULTS_FIELD); |
| 1676 | + |
| 1677 | + GetDeploymentStatsAction.Response response = mock(); |
| 1678 | + when(response.getStats()).thenReturn(queryPage); |
| 1679 | + |
| 1680 | + ActionListener<GetDeploymentStatsAction.Response> listener = ans.getArgument(2); |
| 1681 | + listener.onResponse(response); |
| 1682 | + return null; |
| 1683 | + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); |
| 1684 | + when(client.threadPool()).thenReturn(threadPool); |
| 1685 | + |
| 1686 | + var cs = mock(ClusterService.class); |
| 1687 | + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
| 1688 | + when(cs.getClusterSettings()).thenReturn(cSettings); |
| 1689 | + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( |
| 1690 | + client, |
| 1691 | + threadPool, |
| 1692 | + cs, |
| 1693 | + Settings.builder().put("xpack.ml.enabled", true).build() |
| 1694 | + ); |
| 1695 | + try (var service = new ElasticsearchInternalService(context)) { |
| 1696 | + List<Model> models = List.of(model); |
| 1697 | + var latch = new CountDownLatch(1); |
| 1698 | + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown())); |
| 1699 | + assertTrue(latch.await(30, TimeUnit.SECONDS)); |
| 1700 | + verify(model).updateNumAllocations(3); |
| 1701 | + } |
| 1702 | + } |
| 1703 | + |
1635 | 1704 | private ElasticsearchInternalService createService(Client client) { |
1636 | 1705 | var cs = mock(ClusterService.class); |
1637 | 1706 | var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
|
0 commit comments