|
13 | 13 | import org.elasticsearch.ElasticsearchStatusException;
|
14 | 14 | import org.elasticsearch.action.ActionListener;
|
15 | 15 | import org.elasticsearch.action.LatchedActionListener;
|
| 16 | +import org.elasticsearch.action.support.ActionTestUtils; |
16 | 17 | import org.elasticsearch.action.support.PlainActionFuture;
|
17 | 18 | import org.elasticsearch.client.internal.Client;
|
18 | 19 | import org.elasticsearch.cluster.service.ClusterService;
|
|
47 | 48 | import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
|
48 | 49 | import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
|
49 | 50 | import org.elasticsearch.xpack.core.ml.MachineLearningField;
|
| 51 | +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; |
50 | 52 | import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
51 | 53 | import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
52 | 54 | import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
|
53 | 55 | import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
54 | 56 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
55 | 57 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
|
| 58 | +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; |
56 | 59 | import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
|
57 | 60 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
|
58 | 61 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
|
|
68 | 71 | import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
69 | 72 | import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
|
70 | 73 | import org.elasticsearch.xpack.inference.services.ServiceFields;
|
| 74 | +import org.hamcrest.Matchers; |
71 | 75 | import org.junit.After;
|
72 | 76 | import org.junit.Before;
|
73 | 77 | import org.mockito.ArgumentCaptor;
|
74 | 78 | import org.mockito.Mockito;
|
75 | 79 |
|
| 80 | +import java.io.IOException; |
76 | 81 | import java.util.ArrayList;
|
77 | 82 | import java.util.Arrays;
|
78 | 83 | import java.util.EnumSet;
|
|
82 | 87 | import java.util.Optional;
|
83 | 88 | import java.util.Set;
|
84 | 89 | import java.util.concurrent.CountDownLatch;
|
| 90 | +import java.util.concurrent.TimeUnit; |
85 | 91 | import java.util.concurrent.atomic.AtomicBoolean;
|
86 | 92 | import java.util.concurrent.atomic.AtomicInteger;
|
87 | 93 | import java.util.concurrent.atomic.AtomicReference;
|
88 | 94 |
|
89 | 95 | import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
|
90 | 96 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
|
| 97 | +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; |
91 | 98 | import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
|
92 | 99 | import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID;
|
93 | 100 | import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
|
|
102 | 109 | import static org.mockito.ArgumentMatchers.same;
|
103 | 110 | import static org.mockito.Mockito.doAnswer;
|
104 | 111 | import static org.mockito.Mockito.mock;
|
| 112 | +import static org.mockito.Mockito.verify; |
105 | 113 | import static org.mockito.Mockito.when;
|
106 | 114 |
|
107 | 115 | public class ElasticsearchInternalServiceTests extends ESTestCase {
|
@@ -1698,6 +1706,67 @@ public void testGetConfiguration() throws Exception {
|
1698 | 1706 | }
|
1699 | 1707 | }
|
1700 | 1708 |
|
| 1709 | + public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException { |
| 1710 | + var cs = mock(ClusterService.class); |
| 1711 | + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
| 1712 | + when(cs.getClusterSettings()).thenReturn(cSettings); |
| 1713 | + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( |
| 1714 | + mock(), |
| 1715 | + threadPool, |
| 1716 | + cs, |
| 1717 | + Settings.builder().put("xpack.ml.enabled", false).build() |
| 1718 | + ); |
| 1719 | + try (var service = new ElasticsearchInternalService(context)) { |
| 1720 | + var models = List.of(mock(Model.class)); |
| 1721 | + var latch = new CountDownLatch(1); |
| 1722 | + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> { |
| 1723 | + latch.countDown(); |
| 1724 | + assertThat(r, Matchers.sameInstance(models)); |
| 1725 | + })); |
| 1726 | + assertTrue(latch.await(30, TimeUnit.SECONDS)); |
| 1727 | + } |
| 1728 | + } |
| 1729 | + |
| 1730 | + public void testUpdateWithMlEnabled() throws IOException, InterruptedException { |
| 1731 | + var deploymentId = "deploymentId"; |
| 1732 | + var model = mock(ElasticsearchInternalModel.class); |
| 1733 | + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); |
| 1734 | + |
| 1735 | + AssignmentStats stats = mock(); |
| 1736 | + when(stats.getDeploymentId()).thenReturn(deploymentId); |
| 1737 | + when(stats.getNumberOfAllocations()).thenReturn(3); |
| 1738 | + |
| 1739 | + var client = mock(Client.class); |
| 1740 | + doAnswer(ans -> { |
| 1741 | + QueryPage<AssignmentStats> queryPage = new QueryPage<>(List.of(stats), 1, RESULTS_FIELD); |
| 1742 | + |
| 1743 | + GetDeploymentStatsAction.Response response = mock(); |
| 1744 | + when(response.getStats()).thenReturn(queryPage); |
| 1745 | + |
| 1746 | + ActionListener<GetDeploymentStatsAction.Response> listener = ans.getArgument(2); |
| 1747 | + listener.onResponse(response); |
| 1748 | + return null; |
| 1749 | + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); |
| 1750 | + when(client.threadPool()).thenReturn(threadPool); |
| 1751 | + |
| 1752 | + var cs = mock(ClusterService.class); |
| 1753 | + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
| 1754 | + when(cs.getClusterSettings()).thenReturn(cSettings); |
| 1755 | + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( |
| 1756 | + client, |
| 1757 | + threadPool, |
| 1758 | + cs, |
| 1759 | + Settings.builder().put("xpack.ml.enabled", true).build() |
| 1760 | + ); |
| 1761 | + try (var service = new ElasticsearchInternalService(context)) { |
| 1762 | + List<Model> models = List.of(model); |
| 1763 | + var latch = new CountDownLatch(1); |
| 1764 | + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown())); |
| 1765 | + assertTrue(latch.await(30, TimeUnit.SECONDS)); |
| 1766 | + verify(model).updateNumAllocations(3); |
| 1767 | + } |
| 1768 | + } |
| 1769 | + |
1701 | 1770 | private ElasticsearchInternalService createService(Client client) {
|
1702 | 1771 | var cs = mock(ClusterService.class);
|
1703 | 1772 | var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
|
|
0 commit comments