Skip to content

Commit 20b720b

Browse files
Fix get all inference endpoints not returning multiple endpoints sharing model deployment (#121821) (#122206)
* Fix get all inference endpoints not returning multiple endpoints sharing model deployment * Update docs/changelog/121821.yaml * Clean up modelsByDeploymentId generation code --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent f078c38 commit 20b720b

File tree

3 files changed

+161
-5
lines changed

3 files changed

+161
-5
lines changed

docs/changelog/121821.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 121821
2+
summary: Fix get all inference endpoints not returning multiple endpoints sharing model
3+
deployment
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -854,12 +854,15 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
854854
return;
855855
}
856856

857-
var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
857+
var modelsByDeploymentIds = new HashMap<String, List<ElasticsearchInternalModel>>();
858858
for (var model : models) {
859859
assert model instanceof ElasticsearchInternalModel;
860860

861861
if (model instanceof ElasticsearchInternalModel esModel) {
862-
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
862+
modelsByDeploymentIds.merge(esModel.mlNodeDeploymentId(), new ArrayList<>(List.of(esModel)), (a, b) -> {
863+
a.addAll(b);
864+
return a;
865+
});
863866
} else {
864867
listener.onFailure(
865868
new ElasticsearchStatusException(
@@ -878,10 +881,13 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
878881
new GetDeploymentStatsAction.Request(deploymentIds),
879882
ActionListener.wrap(stats -> {
880883
for (var deploymentStats : stats.getStats().results()) {
881-
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
882-
model.updateNumAllocations(deploymentStats.getNumberOfAllocations());
884+
var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
885+
modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations()));
883886
}
884-
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
887+
var updatedModels = new ArrayList<Model>();
888+
modelsByDeploymentIds.values().forEach(updatedModels::addAll);
889+
890+
listener.onResponse(updatedModels);
885891
}, e -> {
886892
logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e);
887893
// continue with the original response

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
import java.io.IOException;
8080
import java.util.ArrayList;
8181
import java.util.Arrays;
82+
import java.util.Collections;
8283
import java.util.EnumSet;
8384
import java.util.HashMap;
8485
import java.util.List;
@@ -109,6 +110,7 @@
109110
import static org.mockito.Mockito.doAnswer;
110111
import static org.mockito.Mockito.mock;
111112
import static org.mockito.Mockito.verify;
113+
import static org.mockito.Mockito.verifyNoMoreInteractions;
112114
import static org.mockito.Mockito.when;
113115

114116
public class ElasticsearchInternalServiceTests extends ESTestCase {
@@ -1640,6 +1642,148 @@ public void testGetConfiguration() throws Exception {
16401642
}
16411643
}
16421644

1645+
public void testUpdateModelsWithDynamicFields_NoModelsToUpdate() throws Exception {
1646+
ActionListener<List<Model>> resultsListener = ActionListener.<List<Model>>wrap(
1647+
updatedModels -> assertEquals(Collections.emptyList(), updatedModels),
1648+
e -> fail("Unexpected exception: " + e)
1649+
);
1650+
1651+
try (var service = createService(mock(Client.class))) {
1652+
service.updateModelsWithDynamicFields(List.of(), resultsListener);
1653+
}
1654+
}
1655+
1656+
public void testUpdateModelsWithDynamicFields_InvalidModelProvided() throws IOException {
1657+
ActionListener<List<Model>> resultsListener = ActionListener.wrap(
1658+
updatedModels -> fail("Expected invalid model assertion error to be thrown"),
1659+
e -> fail("Expected invalid model assertion error to be thrown")
1660+
);
1661+
1662+
try (var service = createService(mock(Client.class))) {
1663+
assertThrows(
1664+
AssertionError.class,
1665+
() -> { service.updateModelsWithDynamicFields(List.of(mock(Model.class)), resultsListener); }
1666+
);
1667+
}
1668+
}
1669+
1670+
@SuppressWarnings("unchecked")
1671+
public void testUpdateModelsWithDynamicFields_FailsToRetrieveDeployments() throws IOException {
1672+
var deploymentId = randomAlphaOfLength(10);
1673+
var model = mock(ElasticsearchInternalModel.class);
1674+
when(model.mlNodeDeploymentId()).thenReturn(deploymentId);
1675+
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1676+
1677+
ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> {
1678+
assertEquals(updatedModels.size(), 1);
1679+
verify(model).mlNodeDeploymentId();
1680+
verifyNoMoreInteractions(model);
1681+
}, e -> fail("Expected original models to be returned"));
1682+
1683+
var client = mock(Client.class);
1684+
when(client.threadPool()).thenReturn(threadPool);
1685+
doAnswer(invocation -> {
1686+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
1687+
listener.onFailure(new RuntimeException(randomAlphaOfLength(10)));
1688+
return null;
1689+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
1690+
1691+
try (var service = createService(client)) {
1692+
service.updateModelsWithDynamicFields(List.of(model), resultsListener);
1693+
}
1694+
}
1695+
1696+
public void testUpdateModelsWithDynamicFields_SingleModelToUpdate() throws IOException {
1697+
var deploymentId = randomAlphaOfLength(10);
1698+
var model = mock(ElasticsearchInternalModel.class);
1699+
when(model.mlNodeDeploymentId()).thenReturn(deploymentId);
1700+
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1701+
1702+
var modelsByDeploymentId = new HashMap<String, List<Model>>();
1703+
modelsByDeploymentId.put(deploymentId, List.of(model));
1704+
1705+
testUpdateModelsWithDynamicFields(modelsByDeploymentId);
1706+
}
1707+
1708+
public void testUpdateModelsWithDynamicFields_MultipleModelsWithDifferentDeploymentsToUpdate() throws IOException {
1709+
var deploymentId1 = randomAlphaOfLength(10);
1710+
var model1 = mock(ElasticsearchInternalModel.class);
1711+
when(model1.mlNodeDeploymentId()).thenReturn(deploymentId1);
1712+
when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1713+
var deploymentId2 = randomAlphaOfLength(10);
1714+
var model2 = mock(ElasticsearchInternalModel.class);
1715+
when(model2.mlNodeDeploymentId()).thenReturn(deploymentId2);
1716+
when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1717+
1718+
var modelsByDeploymentId = new HashMap<String, List<Model>>();
1719+
modelsByDeploymentId.put(deploymentId1, List.of(model1));
1720+
modelsByDeploymentId.put(deploymentId2, List.of(model2));
1721+
1722+
testUpdateModelsWithDynamicFields(modelsByDeploymentId);
1723+
}
1724+
1725+
public void testUpdateModelsWithDynamicFields_MultipleModelsWithSameDeploymentsToUpdate() throws IOException {
1726+
var deploymentId = randomAlphaOfLength(10);
1727+
var model1 = mock(ElasticsearchInternalModel.class);
1728+
when(model1.mlNodeDeploymentId()).thenReturn(deploymentId);
1729+
when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1730+
var model2 = mock(ElasticsearchInternalModel.class);
1731+
when(model2.mlNodeDeploymentId()).thenReturn(deploymentId);
1732+
when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
1733+
1734+
var modelsByDeploymentId = new HashMap<String, List<Model>>();
1735+
modelsByDeploymentId.put(deploymentId, List.of(model1, model2));
1736+
1737+
testUpdateModelsWithDynamicFields(modelsByDeploymentId);
1738+
}
1739+
1740+
@SuppressWarnings("unchecked")
1741+
private void testUpdateModelsWithDynamicFields(Map<String, List<Model>> modelsByDeploymentId) throws IOException {
1742+
var modelsToUpdate = new ArrayList<Model>();
1743+
modelsByDeploymentId.values().forEach(modelsToUpdate::addAll);
1744+
1745+
var updatedNumberOfAllocations = new HashMap<String, Integer>();
1746+
modelsByDeploymentId.keySet().forEach(deploymentId -> updatedNumberOfAllocations.put(deploymentId, randomIntBetween(1, 10)));
1747+
1748+
ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> {
1749+
assertEquals(updatedModels.size(), modelsToUpdate.size());
1750+
modelsByDeploymentId.forEach((deploymentId, models) -> {
1751+
var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId);
1752+
models.forEach(model -> {
1753+
verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations);
1754+
verify((ElasticsearchInternalModel) model).mlNodeDeploymentId();
1755+
verifyNoMoreInteractions(model);
1756+
});
1757+
});
1758+
}, e -> fail("Unexpected exception: " + e));
1759+
1760+
var client = mock(Client.class);
1761+
when(client.threadPool()).thenReturn(threadPool);
1762+
doAnswer(invocation -> {
1763+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
1764+
var mockAssignmentStats = new ArrayList<AssignmentStats>();
1765+
modelsByDeploymentId.keySet().forEach(deploymentId -> {
1766+
var mockAssignmentStatsForDeploymentId = mock(AssignmentStats.class);
1767+
when(mockAssignmentStatsForDeploymentId.getDeploymentId()).thenReturn(deploymentId);
1768+
when(mockAssignmentStatsForDeploymentId.getNumberOfAllocations()).thenReturn(updatedNumberOfAllocations.get(deploymentId));
1769+
mockAssignmentStats.add(mockAssignmentStatsForDeploymentId);
1770+
});
1771+
listener.onResponse(
1772+
new GetDeploymentStatsAction.Response(
1773+
Collections.emptyList(),
1774+
Collections.emptyList(),
1775+
mockAssignmentStats,
1776+
mockAssignmentStats.size()
1777+
)
1778+
);
1779+
return null;
1780+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
1781+
1782+
try (var service = createService(client)) {
1783+
service.updateModelsWithDynamicFields(modelsToUpdate, resultsListener);
1784+
}
1785+
}
1786+
16431787
public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException {
16441788
var cs = mock(ClusterService.class);
16451789
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

0 commit comments

Comments
 (0)