|
13 | 13 | import org.elasticsearch.ElasticsearchStatusException; |
14 | 14 | import org.elasticsearch.action.ActionListener; |
15 | 15 | import org.elasticsearch.action.LatchedActionListener; |
| 16 | +import org.elasticsearch.action.support.ActionTestUtils; |
16 | 17 | import org.elasticsearch.action.support.PlainActionFuture; |
17 | 18 | import org.elasticsearch.client.internal.Client; |
18 | 19 | import org.elasticsearch.cluster.service.ClusterService; |
|
47 | 48 | import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; |
48 | 49 | import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; |
49 | 50 | import org.elasticsearch.xpack.core.ml.MachineLearningField; |
| 51 | +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; |
50 | 52 | import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; |
51 | 53 | import org.elasticsearch.xpack.core.ml.action.InferModelAction; |
52 | 54 | import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; |
53 | 55 | import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; |
54 | 56 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; |
55 | 57 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; |
| 58 | +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; |
56 | 59 | import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; |
57 | 60 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; |
58 | 61 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; |
|
68 | 71 | import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; |
69 | 72 | import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; |
70 | 73 | import org.elasticsearch.xpack.inference.services.ServiceFields; |
| 74 | +import org.hamcrest.Matchers; |
71 | 75 | import org.junit.After; |
72 | 76 | import org.junit.Before; |
73 | 77 | import org.mockito.ArgumentCaptor; |
74 | 78 | import org.mockito.Mockito; |
75 | 79 |
|
| 80 | +import java.io.IOException; |
76 | 81 | import java.util.ArrayList; |
77 | 82 | import java.util.Arrays; |
78 | 83 | import java.util.EnumSet; |
|
82 | 87 | import java.util.Optional; |
83 | 88 | import java.util.Set; |
84 | 89 | import java.util.concurrent.CountDownLatch; |
| 90 | +import java.util.concurrent.TimeUnit; |
85 | 91 | import java.util.concurrent.atomic.AtomicBoolean; |
86 | 92 | import java.util.concurrent.atomic.AtomicInteger; |
87 | 93 | import java.util.concurrent.atomic.AtomicReference; |
88 | 94 |
|
89 | 95 | import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; |
90 | 96 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; |
| 97 | +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; |
91 | 98 | import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; |
92 | 99 | import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; |
93 | 100 | import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; |
|
102 | 109 | import static org.mockito.ArgumentMatchers.same; |
103 | 110 | import static org.mockito.Mockito.doAnswer; |
104 | 111 | import static org.mockito.Mockito.mock; |
| 112 | +import static org.mockito.Mockito.verify; |
105 | 113 | import static org.mockito.Mockito.when; |
106 | 114 |
|
107 | 115 | public class ElasticsearchInternalServiceTests extends ESTestCase { |
@@ -1698,6 +1706,67 @@ public void testGetConfiguration() throws Exception { |
1698 | 1706 | } |
1699 | 1707 | } |
1700 | 1708 |
|
| 1709 | + public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException { |
| 1710 | + var cs = mock(ClusterService.class); |
| 1711 | + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
| 1712 | + when(cs.getClusterSettings()).thenReturn(cSettings); |
| 1713 | + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( |
| 1714 | + mock(), |
| 1715 | + threadPool, |
| 1716 | + cs, |
| 1717 | + Settings.builder().put("xpack.ml.enabled", false).build() |
| 1718 | + ); |
| 1719 | + try (var service = new ElasticsearchInternalService(context)) { |
| 1720 | + var models = List.of(mock(Model.class)); |
| 1721 | + var latch = new CountDownLatch(1); |
| 1722 | + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> { |
| 1723 | + latch.countDown(); |
| 1724 | + assertThat(r, Matchers.sameInstance(models)); |
| 1725 | + })); |
| 1726 | + assertTrue(latch.await(30, TimeUnit.SECONDS)); |
| 1727 | + } |
| 1728 | + } |
| 1729 | + |
| 1730 | + public void testUpdateWithMlEnabled() throws IOException, InterruptedException { |
| 1731 | + var deploymentId = "deploymentId"; |
| 1732 | + var model = mock(ElasticsearchInternalModel.class); |
| 1733 | + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); |
| 1734 | + |
| 1735 | + AssignmentStats stats = mock(); |
| 1736 | + when(stats.getDeploymentId()).thenReturn(deploymentId); |
| 1737 | + when(stats.getNumberOfAllocations()).thenReturn(3); |
| 1738 | + |
| 1739 | + var client = mock(Client.class); |
| 1740 | + doAnswer(ans -> { |
| 1741 | + QueryPage<AssignmentStats> queryPage = new QueryPage<>(List.of(stats), 1, RESULTS_FIELD); |
| 1742 | + |
| 1743 | + GetDeploymentStatsAction.Response response = mock(); |
| 1744 | + when(response.getStats()).thenReturn(queryPage); |
| 1745 | + |
| 1746 | + ActionListener<GetDeploymentStatsAction.Response> listener = ans.getArgument(2); |
| 1747 | + listener.onResponse(response); |
| 1748 | + return null; |
| 1749 | + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); |
| 1750 | + when(client.threadPool()).thenReturn(threadPool); |
| 1751 | + |
| 1752 | + var cs = mock(ClusterService.class); |
| 1753 | + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
| 1754 | + when(cs.getClusterSettings()).thenReturn(cSettings); |
| 1755 | + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( |
| 1756 | + client, |
| 1757 | + threadPool, |
| 1758 | + cs, |
| 1759 | + Settings.builder().put("xpack.ml.enabled", true).build() |
| 1760 | + ); |
| 1761 | + try (var service = new ElasticsearchInternalService(context)) { |
| 1762 | + List<Model> models = List.of(model); |
| 1763 | + var latch = new CountDownLatch(1); |
| 1764 | + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown())); |
| 1765 | + assertTrue(latch.await(30, TimeUnit.SECONDS)); |
| 1766 | + verify(model).updateNumAllocations(3); |
| 1767 | + } |
| 1768 | + } |
| 1769 | + |
1701 | 1770 | private ElasticsearchInternalService createService(Client client) { |
1702 | 1771 | var cs = mock(ClusterService.class); |
1703 | 1772 | var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
|
0 commit comments