Skip to content

Commit 23175fc

Browse files
Fix get all inference endpoints not returning multiple endpoints sharing model deployment
1 parent c77afd8 commit 23175fc

File tree

2 files changed

+165
-5
lines changed

2 files changed

+165
-5
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -843,12 +843,20 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
843843
return;
844844
}
845845

846-
var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
846+
var modelsByDeploymentIds = new HashMap<String, List<ElasticsearchInternalModel>>();
847847
for (var model : models) {
848848
assert model instanceof ElasticsearchInternalModel;
849849

850850
if (model instanceof ElasticsearchInternalModel esModel) {
851-
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
851+
if (modelsByDeploymentIds.containsKey(esModel.mlNodeDeploymentId()) == false) {
852+
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), new ArrayList<>() {
853+
{
854+
add(esModel);
855+
}
856+
});
857+
} else {
858+
modelsByDeploymentIds.get(esModel.mlNodeDeploymentId()).add(esModel);
859+
}
852860
} else {
853861
listener.onFailure(
854862
new ElasticsearchStatusException(
@@ -867,10 +875,13 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
867875
new GetDeploymentStatsAction.Request(deploymentIds),
868876
ActionListener.wrap(stats -> {
869877
for (var deploymentStats : stats.getStats().results()) {
870-
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
871-
model.updateNumAllocations(deploymentStats.getNumberOfAllocations());
878+
var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
879+
modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations()));
872880
}
873-
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
881+
var updatedModels = new ArrayList<Model>();
882+
modelsByDeploymentIds.values().forEach(updatedModels::addAll);
883+
884+
listener.onResponse(updatedModels);
874885
}, e -> {
875886
logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e);
876887
// continue with the original response

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@
4646
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
4747
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
4848
import org.elasticsearch.xpack.core.ml.MachineLearningField;
49+
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
4950
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
5051
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
5152
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
5253
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
5354
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
5455
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
56+
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
5557
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
5658
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
5759
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
@@ -72,8 +74,10 @@
7274
import org.mockito.ArgumentCaptor;
7375
import org.mockito.Mockito;
7476

77+
import java.io.IOException;
7578
import java.util.ArrayList;
7679
import java.util.Arrays;
80+
import java.util.Collections;
7781
import java.util.EnumSet;
7882
import java.util.HashMap;
7983
import java.util.List;
@@ -101,6 +105,9 @@
101105
import static org.mockito.ArgumentMatchers.same;
102106
import static org.mockito.Mockito.doAnswer;
103107
import static org.mockito.Mockito.mock;
108+
import static org.mockito.Mockito.times;
109+
import static org.mockito.Mockito.verify;
110+
import static org.mockito.Mockito.verifyNoMoreInteractions;
104111
import static org.mockito.Mockito.when;
105112

106113
public class ElasticsearchInternalServiceTests extends ESTestCase {
@@ -1632,6 +1639,148 @@ public void testGetConfiguration() throws Exception {
16321639
}
16331640
}
16341641

1642+
public void testUpdateModelsWithDynamicFields_NoModelsToUpdate() throws Exception {
1643+
ActionListener<List<Model>> resultsListener = ActionListener.<List<Model>>wrap(
1644+
updatedModels -> assertEquals(Collections.emptyList(), updatedModels),
1645+
e -> fail("Unexpected exception: " + e)
1646+
);
1647+
1648+
try (var service = createService(mock(Client.class))) {
1649+
service.updateModelsWithDynamicFields(List.of(), resultsListener);
1650+
}
1651+
}
1652+
1653+
public void testUpdateModelsWithDynamicFields_InvalidModelProvided() throws IOException {
1654+
ActionListener<List<Model>> resultsListener = ActionListener.wrap(
1655+
updatedModels -> fail("Expected invalid model assertion error to be thrown"),
1656+
e -> fail("Expected invalid model assertion error to be thrown")
1657+
);
1658+
1659+
try (var service = createService(mock(Client.class))) {
1660+
assertThrows(
1661+
AssertionError.class,
1662+
() -> { service.updateModelsWithDynamicFields(List.of(mock(Model.class)), resultsListener); }
1663+
);
1664+
}
1665+
}
1666+
1667+
@SuppressWarnings("unchecked")
1668+
public void testUpdateModelsWithDynamicFields_FailsToRetrieveDeployments() throws IOException {
1669+
var deploymentId = randomAlphaOfLength(10);
1670+
var model = mock(ElasticsearchInternalModel.class);
1671+
when(model.mlNodeDeploymentId()).thenReturn(deploymentId);
1672+
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1673+
1674+
ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> {
1675+
assertEquals(updatedModels.size(), 1);
1676+
verify(model, times(2)).mlNodeDeploymentId();
1677+
verifyNoMoreInteractions(model);
1678+
}, e -> fail("Expected original models to be returned"));
1679+
1680+
var client = mock(Client.class);
1681+
when(client.threadPool()).thenReturn(threadPool);
1682+
doAnswer(invocation -> {
1683+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
1684+
listener.onFailure(new RuntimeException(randomAlphaOfLength(10)));
1685+
return null;
1686+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
1687+
1688+
try (var service = createService(client)) {
1689+
service.updateModelsWithDynamicFields(List.of(model), resultsListener);
1690+
}
1691+
}
1692+
1693+
public void testUpdateModelsWithDynamicFields_SingleModelToUpdate() throws IOException {
1694+
var deploymentId = randomAlphaOfLength(10);
1695+
var model = mock(ElasticsearchInternalModel.class);
1696+
when(model.mlNodeDeploymentId()).thenReturn(deploymentId);
1697+
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1698+
1699+
var modelsByDeploymentId = new HashMap<String, List<Model>>();
1700+
modelsByDeploymentId.put(deploymentId, List.of(model));
1701+
1702+
testUpdateModelsWithDynamicFields(modelsByDeploymentId);
1703+
}
1704+
1705+
public void testUpdateModelsWithDynamicFields_MultipleModelsWithDifferentDeploymentsToUpdate() throws IOException {
1706+
var deploymentId1 = randomAlphaOfLength(10);
1707+
var model1 = mock(ElasticsearchInternalModel.class);
1708+
when(model1.mlNodeDeploymentId()).thenReturn(deploymentId1);
1709+
when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1710+
var deploymentId2 = randomAlphaOfLength(10);
1711+
var model2 = mock(ElasticsearchInternalModel.class);
1712+
when(model2.mlNodeDeploymentId()).thenReturn(deploymentId2);
1713+
when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1714+
1715+
var modelsByDeploymentId = new HashMap<String, List<Model>>();
1716+
modelsByDeploymentId.put(deploymentId1, List.of(model1));
1717+
modelsByDeploymentId.put(deploymentId2, List.of(model2));
1718+
1719+
testUpdateModelsWithDynamicFields(modelsByDeploymentId);
1720+
}
1721+
1722+
public void testUpdateModelsWithDynamicFields_MultipleModelsWithSameDeploymentsToUpdate() throws IOException {
1723+
var deploymentId = randomAlphaOfLength(10);
1724+
var model1 = mock(ElasticsearchInternalModel.class);
1725+
when(model1.mlNodeDeploymentId()).thenReturn(deploymentId);
1726+
when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1727+
var model2 = mock(ElasticsearchInternalModel.class);
1728+
when(model2.mlNodeDeploymentId()).thenReturn(deploymentId);
1729+
when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1730+
1731+
var modelsByDeploymentId = new HashMap<String, List<Model>>();
1732+
modelsByDeploymentId.put(deploymentId, List.of(model1, model2));
1733+
1734+
testUpdateModelsWithDynamicFields(modelsByDeploymentId);
1735+
}
1736+
1737+
@SuppressWarnings("unchecked")
1738+
private void testUpdateModelsWithDynamicFields(Map<String, List<Model>> modelsByDeploymentId) throws IOException {
1739+
var modelsToUpdate = new ArrayList<Model>();
1740+
modelsByDeploymentId.values().forEach(modelsToUpdate::addAll);
1741+
1742+
var updatedNumberOfAllocations = new HashMap<String, Integer>();
1743+
modelsByDeploymentId.keySet().forEach(deploymentId -> updatedNumberOfAllocations.put(deploymentId, randomIntBetween(1, 10)));
1744+
1745+
ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> {
1746+
assertEquals(updatedModels.size(), modelsToUpdate.size());
1747+
modelsByDeploymentId.forEach((deploymentId, models) -> {
1748+
var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId);
1749+
models.forEach(model -> {
1750+
verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations);
1751+
verify((ElasticsearchInternalModel) model, times(2)).mlNodeDeploymentId();
1752+
verifyNoMoreInteractions(model);
1753+
});
1754+
});
1755+
}, e -> fail("Unexpected exception: " + e));
1756+
1757+
var client = mock(Client.class);
1758+
when(client.threadPool()).thenReturn(threadPool);
1759+
doAnswer(invocation -> {
1760+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
1761+
var mockAssignmentStats = new ArrayList<AssignmentStats>();
1762+
modelsByDeploymentId.keySet().forEach(deploymentId -> {
1763+
var mockAssignmentStatsForDeploymentId = mock(AssignmentStats.class);
1764+
when(mockAssignmentStatsForDeploymentId.getDeploymentId()).thenReturn(deploymentId);
1765+
when(mockAssignmentStatsForDeploymentId.getNumberOfAllocations()).thenReturn(updatedNumberOfAllocations.get(deploymentId));
1766+
mockAssignmentStats.add(mockAssignmentStatsForDeploymentId);
1767+
});
1768+
listener.onResponse(
1769+
new GetDeploymentStatsAction.Response(
1770+
Collections.emptyList(),
1771+
Collections.emptyList(),
1772+
mockAssignmentStats,
1773+
mockAssignmentStats.size()
1774+
)
1775+
);
1776+
return null;
1777+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
1778+
1779+
try (var service = createService(client)) {
1780+
service.updateModelsWithDynamicFields(modelsToUpdate, resultsListener);
1781+
}
1782+
}
1783+
16351784
private ElasticsearchInternalService createService(Client client) {
16361785
var cs = mock(ClusterService.class);
16371786
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

0 commit comments

Comments
 (0)