Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/121821.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 121821
summary: Fix get all inference endpoints not returning multiple endpoints sharing model
deployment
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -854,12 +854,20 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
return;
}

var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
var modelsByDeploymentIds = new HashMap<String, List<ElasticsearchInternalModel>>();
for (var model : models) {
assert model instanceof ElasticsearchInternalModel;

if (model instanceof ElasticsearchInternalModel esModel) {
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
if (modelsByDeploymentIds.containsKey(esModel.mlNodeDeploymentId()) == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think the if-else can be distilled to something like this:

modelsByDeploymentIds.merge(
  esModel.mlNodeDeploymentId(),
  new ArrayList<>(List.of(esModel)), (a, b) -> {
    a.addAll(b);
    return a;
});

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I like this method much better! I'll go ahead and make that change.

modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), new ArrayList<>() {
{
add(esModel);
}
});
} else {
modelsByDeploymentIds.get(esModel.mlNodeDeploymentId()).add(esModel);
}
} else {
listener.onFailure(
new ElasticsearchStatusException(
Expand All @@ -878,10 +886,13 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
new GetDeploymentStatsAction.Request(deploymentIds),
ActionListener.wrap(stats -> {
for (var deploymentStats : stats.getStats().results()) {
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
model.updateNumAllocations(deploymentStats.getNumberOfAllocations());
var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations()));
}
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
var updatedModels = new ArrayList<Model>();
modelsByDeploymentIds.values().forEach(updatedModels::addAll);

listener.onResponse(updatedModels);
}, e -> {
logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e);
// continue with the original response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -108,7 +109,9 @@
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class ElasticsearchInternalServiceTests extends ESTestCase {
Expand Down Expand Up @@ -1640,6 +1643,148 @@ public void testGetConfiguration() throws Exception {
}
}

/**
 * Calls {@code updateModelsWithDynamicFields} with no models and expects the listener to be
 * handed an empty list; any failure callback fails the test.
 * <p>
 * The client is a bare mock with no stubbing, so this test gives it no behaviour to rely on.
 */
public void testUpdateModelsWithDynamicFields_NoModelsToUpdate() throws Exception {
    List<Model> expectedModels = Collections.emptyList();
    ActionListener<List<Model>> listener = ActionListener.wrap(
        actualModels -> assertEquals(expectedModels, actualModels),
        failure -> fail("Unexpected exception: " + failure)
    );

    try (var service = createService(mock(Client.class))) {
        List<Model> noModels = List.of();
        service.updateModelsWithDynamicFields(noModels, listener);
    }
}

/**
 * Passes a plain {@code Model} mock (not an {@code ElasticsearchInternalModel}) and expects
 * the call itself to throw an {@link AssertionError}.
 * <p>
 * Both listener callbacks fail the test if they are ever invoked.
 */
public void testUpdateModelsWithDynamicFields_InvalidModelProvided() throws IOException {
    ActionListener<List<Model>> listener = ActionListener.wrap(
        ignoredModels -> fail("Expected invalid model assertion error to be thrown"),
        ignoredFailure -> fail("Expected invalid model assertion error to be thrown")
    );

    try (var service = createService(mock(Client.class))) {
        assertThrows(AssertionError.class, () -> service.updateModelsWithDynamicFields(List.of(mock(Model.class)), listener));
    }
}

/**
 * Simulates a failing {@code GetDeploymentStatsAction} request and expects the service to
 * still answer the listener with the single original model, untouched.
 * <p>
 * "Untouched" is checked through Mockito: apart from two {@code mlNodeDeploymentId()} reads,
 * no other interaction with the model mock is allowed (in particular no
 * {@code updateNumAllocations} call).
 */
@SuppressWarnings("unchecked")
public void testUpdateModelsWithDynamicFields_FailsToRetrieveDeployments() throws IOException {
    var deploymentId = randomAlphaOfLength(10);
    var model = mock(ElasticsearchInternalModel.class);
    when(model.mlNodeDeploymentId()).thenReturn(deploymentId);
    when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

    ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> {
        // Expected value goes first so that a failure message labels the values correctly.
        assertEquals(1, updatedModels.size());
        verify(model, times(2)).mlNodeDeploymentId();
        verifyNoMoreInteractions(model);
    }, e -> fail("Expected original models to be returned"));

    var client = mock(Client.class);
    when(client.threadPool()).thenReturn(threadPool);
    // Answer every deployment-stats request by failing the listener passed as the third argument.
    doAnswer(invocation -> {
        var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
        listener.onFailure(new RuntimeException(randomAlphaOfLength(10)));
        return null;
    }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());

    try (var service = createService(client)) {
        service.updateModelsWithDynamicFields(List.of(model), resultsListener);
    }
}

/**
 * Runs the shared update scenario with exactly one mocked model attached to one deployment id.
 */
public void testUpdateModelsWithDynamicFields_SingleModelToUpdate() throws IOException {
    String deploymentId = randomAlphaOfLength(10);

    ElasticsearchInternalModel endpoint = mock(ElasticsearchInternalModel.class);
    when(endpoint.mlNodeDeploymentId()).thenReturn(deploymentId);
    when(endpoint.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

    Map<String, List<Model>> endpointsByDeployment = new HashMap<>();
    endpointsByDeployment.put(deploymentId, List.of(endpoint));

    testUpdateModelsWithDynamicFields(endpointsByDeployment);
}

/**
 * Runs the shared update scenario with two mocked models, each registered under its own
 * randomly generated deployment id.
 */
public void testUpdateModelsWithDynamicFields_MultipleModelsWithDifferentDeploymentsToUpdate() throws IOException {
    String firstDeploymentId = randomAlphaOfLength(10);
    String secondDeploymentId = randomAlphaOfLength(10);

    ElasticsearchInternalModel firstEndpoint = mock(ElasticsearchInternalModel.class);
    when(firstEndpoint.mlNodeDeploymentId()).thenReturn(firstDeploymentId);
    when(firstEndpoint.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

    ElasticsearchInternalModel secondEndpoint = mock(ElasticsearchInternalModel.class);
    when(secondEndpoint.mlNodeDeploymentId()).thenReturn(secondDeploymentId);
    when(secondEndpoint.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

    Map<String, List<Model>> endpointsByDeployment = new HashMap<>();
    endpointsByDeployment.put(firstDeploymentId, List.of(firstEndpoint));
    endpointsByDeployment.put(secondDeploymentId, List.of(secondEndpoint));

    testUpdateModelsWithDynamicFields(endpointsByDeployment);
}

/**
 * Runs the shared update scenario with two mocked models that report the same deployment id,
 * i.e. two endpoints sharing one deployment.
 */
public void testUpdateModelsWithDynamicFields_MultipleModelsWithSameDeploymentsToUpdate() throws IOException {
    String sharedDeploymentId = randomAlphaOfLength(10);

    ElasticsearchInternalModel firstEndpoint = mock(ElasticsearchInternalModel.class);
    when(firstEndpoint.mlNodeDeploymentId()).thenReturn(sharedDeploymentId);
    when(firstEndpoint.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

    ElasticsearchInternalModel secondEndpoint = mock(ElasticsearchInternalModel.class);
    when(secondEndpoint.mlNodeDeploymentId()).thenReturn(sharedDeploymentId);
    when(secondEndpoint.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

    Map<String, List<Model>> endpointsByDeployment = new HashMap<>();
    endpointsByDeployment.put(sharedDeploymentId, List.of(firstEndpoint, secondEndpoint));

    testUpdateModelsWithDynamicFields(endpointsByDeployment);
}

/**
 * Shared scenario for the "models to update" tests.
 * <p>
 * Flattens the given map into the list of models passed to the service, picks a random
 * allocation count (1-10) per deployment id, and stubs the client so that every
 * {@code GetDeploymentStatsAction} request is answered with one mocked {@code AssignmentStats}
 * per deployment id carrying that count. The result listener then checks that:
 * <ul>
 *   <li>as many models are returned as were passed in;</li>
 *   <li>each model received {@code updateNumAllocations} with the count chosen for its deployment;</li>
 *   <li>each model's {@code mlNodeDeploymentId()} was read exactly twice and nothing else was called on it.</li>
 * </ul>
 *
 * @param modelsByDeploymentId mocked models grouped by the deployment id they report; the values are
 *                             cast to {@code ElasticsearchInternalModel} for verification
 * @throws IOException declared because the service is used in a try-with-resources block
 */
@SuppressWarnings("unchecked")
private void testUpdateModelsWithDynamicFields(Map<String, List<Model>> modelsByDeploymentId) throws IOException {
    var modelsToUpdate = new ArrayList<Model>();
    modelsByDeploymentId.values().forEach(modelsToUpdate::addAll);

    var updatedNumberOfAllocations = new HashMap<String, Integer>();
    modelsByDeploymentId.keySet().forEach(deploymentId -> updatedNumberOfAllocations.put(deploymentId, randomIntBetween(1, 10)));

    ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> {
        // Expected value goes first so that a failure message labels the values correctly.
        assertEquals(modelsToUpdate.size(), updatedModels.size());
        modelsByDeploymentId.forEach((deploymentId, models) -> {
            var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId);
            models.forEach(model -> {
                verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations);
                verify((ElasticsearchInternalModel) model, times(2)).mlNodeDeploymentId();
                verifyNoMoreInteractions(model);
            });
        });
    }, e -> fail("Unexpected exception: " + e));

    var client = mock(Client.class);
    when(client.threadPool()).thenReturn(threadPool);
    // Answer every deployment-stats request with one stats entry per known deployment id.
    doAnswer(invocation -> {
        var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
        var mockAssignmentStats = new ArrayList<AssignmentStats>();
        modelsByDeploymentId.keySet().forEach(deploymentId -> {
            var mockAssignmentStatsForDeploymentId = mock(AssignmentStats.class);
            when(mockAssignmentStatsForDeploymentId.getDeploymentId()).thenReturn(deploymentId);
            when(mockAssignmentStatsForDeploymentId.getNumberOfAllocations()).thenReturn(updatedNumberOfAllocations.get(deploymentId));
            mockAssignmentStats.add(mockAssignmentStatsForDeploymentId);
        });
        listener.onResponse(
            new GetDeploymentStatsAction.Response(
                Collections.emptyList(),
                Collections.emptyList(),
                mockAssignmentStats,
                mockAssignmentStats.size()
            )
        );
        return null;
    }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());

    try (var service = createService(client)) {
        service.updateModelsWithDynamicFields(modelsToUpdate, resultsListener);
    }
}

public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException {
var cs = mock(ClusterService.class);
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
Expand Down