From f717b6654074d3ddc5ed3706409ef35cd0e30b3a Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Mon, 10 Feb 2025 12:40:37 -0500 Subject: [PATCH] Fix get all inference endpoints not returning multiple endpoints sharing model deployment (#121821) * Fix get all inference endpoints not returning multiple endpoints sharing model deployment * Update docs/changelog/121821.yaml * Clean up modelsByDeploymentId generation code --------- Co-authored-by: Elastic Machine --- docs/changelog/121821.yaml | 6 + .../ElasticsearchInternalService.java | 16 +- .../ElasticsearchInternalServiceTests.java | 144 ++++++++++++++++++ 3 files changed, 161 insertions(+), 5 deletions(-) create mode 100644 docs/changelog/121821.yaml diff --git a/docs/changelog/121821.yaml b/docs/changelog/121821.yaml new file mode 100644 index 0000000000000..1e8edd09dcd9a --- /dev/null +++ b/docs/changelog/121821.yaml @@ -0,0 +1,6 @@ +pr: 121821 +summary: Fix get all inference endpoints not returning multiple endpoints sharing model + deployment +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index c7f19adb269ac..ddc5e3e1aa36c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -854,12 +854,15 @@ public void updateModelsWithDynamicFields(List models, ActionListener(); + var modelsByDeploymentIds = new HashMap>(); for (var model : models) { assert model instanceof ElasticsearchInternalModel; if (model instanceof ElasticsearchInternalModel esModel) { - modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel); + 
modelsByDeploymentIds.merge(esModel.mlNodeDeploymentId(), new ArrayList<>(List.of(esModel)), (a, b) -> { + a.addAll(b); + return a; + }); } else { listener.onFailure( new ElasticsearchStatusException( @@ -878,10 +881,13 @@ public void updateModelsWithDynamicFields(List models, ActionListener { for (var deploymentStats : stats.getStats().results()) { - var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId()); - model.updateNumAllocations(deploymentStats.getNumberOfAllocations()); + var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId()); + modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations())); } - listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values())); + var updatedModels = new ArrayList(); + modelsByDeploymentIds.values().forEach(updatedModels::addAll); + + listener.onResponse(updatedModels); }, e -> { logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e); // continue with the original response diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 580871bb2c9a7..d1ce79b863c61 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -79,6 +79,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.List; @@ -109,6 +110,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class ElasticsearchInternalServiceTests extends ESTestCase { @@ -1640,6 +1642,148 @@ public void testGetConfiguration() throws Exception { } } + public void testUpdateModelsWithDynamicFields_NoModelsToUpdate() throws Exception { + ActionListener> resultsListener = ActionListener.>wrap( + updatedModels -> assertEquals(Collections.emptyList(), updatedModels), + e -> fail("Unexpected exception: " + e) + ); + + try (var service = createService(mock(Client.class))) { + service.updateModelsWithDynamicFields(List.of(), resultsListener); + } + } + + public void testUpdateModelsWithDynamicFields_InvalidModelProvided() throws IOException { + ActionListener> resultsListener = ActionListener.wrap( + updatedModels -> fail("Expected invalid model assertion error to be thrown"), + e -> fail("Expected invalid model assertion error to be thrown") + ); + + try (var service = createService(mock(Client.class))) { + assertThrows( + AssertionError.class, + () -> { service.updateModelsWithDynamicFields(List.of(mock(Model.class)), resultsListener); } + ); + } + } + + @SuppressWarnings("unchecked") + public void testUpdateModelsWithDynamicFields_FailsToRetrieveDeployments() throws IOException { + var deploymentId = randomAlphaOfLength(10); + var model = mock(ElasticsearchInternalModel.class); + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + ActionListener> resultsListener = ActionListener.wrap(updatedModels -> { + assertEquals(updatedModels.size(), 1); + verify(model).mlNodeDeploymentId(); + verifyNoMoreInteractions(model); + }, e -> fail("Expected original models to be returned")); + + var client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocation -> { + var listener = (ActionListener) invocation.getArguments()[2]; + 
listener.onFailure(new RuntimeException(randomAlphaOfLength(10))); + return null; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + + try (var service = createService(client)) { + service.updateModelsWithDynamicFields(List.of(model), resultsListener); + } + } + + public void testUpdateModelsWithDynamicFields_SingleModelToUpdate() throws IOException { + var deploymentId = randomAlphaOfLength(10); + var model = mock(ElasticsearchInternalModel.class); + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + var modelsByDeploymentId = new HashMap>(); + modelsByDeploymentId.put(deploymentId, List.of(model)); + + testUpdateModelsWithDynamicFields(modelsByDeploymentId); + } + + public void testUpdateModelsWithDynamicFields_MultipleModelsWithDifferentDeploymentsToUpdate() throws IOException { + var deploymentId1 = randomAlphaOfLength(10); + var model1 = mock(ElasticsearchInternalModel.class); + when(model1.mlNodeDeploymentId()).thenReturn(deploymentId1); + when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + var deploymentId2 = randomAlphaOfLength(10); + var model2 = mock(ElasticsearchInternalModel.class); + when(model2.mlNodeDeploymentId()).thenReturn(deploymentId2); + when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + var modelsByDeploymentId = new HashMap>(); + modelsByDeploymentId.put(deploymentId1, List.of(model1)); + modelsByDeploymentId.put(deploymentId2, List.of(model2)); + + testUpdateModelsWithDynamicFields(modelsByDeploymentId); + } + + public void testUpdateModelsWithDynamicFields_MultipleModelsWithSameDeploymentsToUpdate() throws IOException { + var deploymentId = randomAlphaOfLength(10); + var model1 = mock(ElasticsearchInternalModel.class); + when(model1.mlNodeDeploymentId()).thenReturn(deploymentId); + when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + var model2 = mock(ElasticsearchInternalModel.class); + 
when(model2.mlNodeDeploymentId()).thenReturn(deploymentId); + when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); + + var modelsByDeploymentId = new HashMap>(); + modelsByDeploymentId.put(deploymentId, List.of(model1, model2)); + + testUpdateModelsWithDynamicFields(modelsByDeploymentId); + } + + @SuppressWarnings("unchecked") + private void testUpdateModelsWithDynamicFields(Map> modelsByDeploymentId) throws IOException { + var modelsToUpdate = new ArrayList(); + modelsByDeploymentId.values().forEach(modelsToUpdate::addAll); + + var updatedNumberOfAllocations = new HashMap(); + modelsByDeploymentId.keySet().forEach(deploymentId -> updatedNumberOfAllocations.put(deploymentId, randomIntBetween(1, 10))); + + ActionListener> resultsListener = ActionListener.wrap(updatedModels -> { + assertEquals(updatedModels.size(), modelsToUpdate.size()); + modelsByDeploymentId.forEach((deploymentId, models) -> { + var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId); + models.forEach(model -> { + verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations); + verify((ElasticsearchInternalModel) model).mlNodeDeploymentId(); + verifyNoMoreInteractions(model); + }); + }); + }, e -> fail("Unexpected exception: " + e)); + + var client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocation -> { + var listener = (ActionListener) invocation.getArguments()[2]; + var mockAssignmentStats = new ArrayList(); + modelsByDeploymentId.keySet().forEach(deploymentId -> { + var mockAssignmentStatsForDeploymentId = mock(AssignmentStats.class); + when(mockAssignmentStatsForDeploymentId.getDeploymentId()).thenReturn(deploymentId); + when(mockAssignmentStatsForDeploymentId.getNumberOfAllocations()).thenReturn(updatedNumberOfAllocations.get(deploymentId)); + mockAssignmentStats.add(mockAssignmentStatsForDeploymentId); + }); + listener.onResponse( + new GetDeploymentStatsAction.Response( + 
Collections.emptyList(), + Collections.emptyList(), + mockAssignmentStats, + mockAssignmentStats.size() + ) + ); + return null; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + + try (var service = createService(client)) { + service.updateModelsWithDynamicFields(modelsToUpdate, resultsListener); + } + } + public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException { var cs = mock(ClusterService.class); var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));