Skip to content

Commit 8040fbb

Browse files
authored
[ML] Dynamically get of num allocations (#114636)
1 parent 01bfdf8 commit 8040fbb

File tree

7 files changed

+148
-24
lines changed

7 files changed

+148
-24
lines changed

docs/changelog/114636.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114636
2+
summary: Dynamically get of num allocations
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,8 @@ default List<DefaultConfigId> defaultConfigIds() {
210210
default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
211211
defaultsListener.onResponse(List.of());
212212
}
213+
214+
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
215+
listener.onResponse(model);
216+
}
213217
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12-
import org.elasticsearch.action.ActionRunnable;
1312
import org.elasticsearch.action.support.ActionFilters;
13+
import org.elasticsearch.action.support.GroupedActionListener;
1414
import org.elasticsearch.action.support.HandledTransportAction;
1515
import org.elasticsearch.common.Strings;
1616
import org.elasticsearch.common.util.concurrent.EsExecutors;
1717
import org.elasticsearch.inference.InferenceServiceRegistry;
18-
import org.elasticsearch.inference.ModelConfigurations;
18+
import org.elasticsearch.inference.Model;
1919
import org.elasticsearch.inference.TaskType;
2020
import org.elasticsearch.inference.UnparsedModel;
2121
import org.elasticsearch.injection.guice.Inject;
@@ -29,8 +29,11 @@
2929
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3030

3131
import java.util.ArrayList;
32+
import java.util.Comparator;
33+
import java.util.HashMap;
3234
import java.util.List;
3335
import java.util.concurrent.Executor;
36+
import java.util.stream.Collectors;
3437

3538
public class TransportGetInferenceModelAction extends HandledTransportAction<
3639
GetInferenceModelAction.Request,
@@ -96,38 +99,69 @@ private void getSingleModel(
9699

97100
var model = service.get()
98101
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
99-
delegate.onResponse(new GetInferenceModelAction.Response(List.of(model.getConfigurations())));
102+
103+
service.get()
104+
.updateModelsWithDynamicFields(
105+
List.of(model),
106+
delegate.delegateFailureAndWrap(
107+
(l2, updatedModels) -> l2.onResponse(
108+
new GetInferenceModelAction.Response(
109+
updatedModels.stream().map(Model::getConfigurations).collect(Collectors.toList())
110+
)
111+
)
112+
)
113+
);
100114
}));
101115
}
102116

103117
private void getAllModels(ActionListener<GetInferenceModelAction.Response> listener) {
104-
modelRegistry.getAllModels(
105-
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
106-
);
118+
modelRegistry.getAllModels(listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener))));
107119
}
108120

109121
private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {
110122
modelRegistry.getModelsByTaskType(
111123
taskType,
112-
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
124+
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))
113125
);
114126
}
115127

116-
private GetInferenceModelAction.Response parseModels(List<UnparsedModel> unparsedModels) {
117-
var parsedModels = new ArrayList<ModelConfigurations>();
118-
119-
for (var unparsedModel : unparsedModels) {
120-
var service = serviceRegistry.getService(unparsedModel.service());
121-
if (service.isEmpty()) {
122-
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
128+
private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
129+
var parsedModelsByService = new HashMap<String, List<Model>>();
130+
try {
131+
for (var unparsedModel : unparsedModels) {
132+
var service = serviceRegistry.getService(unparsedModel.service());
133+
if (service.isEmpty()) {
134+
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
135+
}
136+
var list = parsedModelsByService.computeIfAbsent(service.get().name(), s -> new ArrayList<>());
137+
list.add(
138+
service.get()
139+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
140+
);
123141
}
124-
parsedModels.add(
125-
service.get()
126-
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
127-
.getConfigurations()
142+
143+
var groupedListener = new GroupedActionListener<List<Model>>(
144+
parsedModelsByService.entrySet().size(),
145+
listener.delegateFailureAndWrap((delegate, listOfListOfModels) -> {
146+
var modifiable = new ArrayList<Model>();
147+
for (var l : listOfListOfModels) {
148+
modifiable.addAll(l);
149+
}
150+
modifiable.sort(Comparator.comparing(Model::getInferenceEntityId));
151+
delegate.onResponse(
152+
new GetInferenceModelAction.Response(modifiable.stream().map(Model::getConfigurations).collect(Collectors.toList()))
153+
);
154+
})
128155
);
156+
157+
for (var entry : parsedModelsByService.entrySet()) {
158+
serviceRegistry.getService(entry.getKey())
159+
.get() // must be non-null to get this far
160+
.updateModelsWithDynamicFields(entry.getValue(), groupedListener);
161+
}
162+
} catch (Exception e) {
163+
listener.onFailure(e);
129164
}
130-
return new GetInferenceModelAction.Response(parsedModels);
131165
}
132166

133167
private ElasticsearchStatusException serviceNotFoundException(String service, String inferenceId) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
public abstract class ElasticsearchInternalModel extends Model {
2323

24-
protected final ElasticsearchInternalServiceSettings internalServiceSettings;
24+
protected ElasticsearchInternalServiceSettings internalServiceSettings;
2525

2626
public ElasticsearchInternalModel(
2727
String inferenceEntityId,
@@ -87,6 +87,15 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
8787
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
8888
}
8989

90+
public void updateNumAllocation(Integer numAllocations) {
91+
this.internalServiceSettings = new ElasticsearchInternalServiceSettings(
92+
numAllocations,
93+
this.internalServiceSettings.getNumThreads(),
94+
this.internalServiceSettings.modelId(),
95+
this.internalServiceSettings.getAdaptiveAllocationsSettings()
96+
);
97+
}
98+
9099
@Override
91100
public String toString() {
92101
return Strings.toString(this.getConfigurations());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
3333
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
3434
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
35+
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
3536
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
3637
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
3738
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
@@ -50,6 +51,7 @@
5051
import java.util.ArrayList;
5152
import java.util.Collections;
5253
import java.util.EnumSet;
54+
import java.util.HashMap;
5355
import java.util.List;
5456
import java.util.Map;
5557
import java.util.Set;
@@ -801,11 +803,47 @@ public List<DefaultConfigId> defaultConfigIds() {
801803
return List.of(new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this));
802804
}
803805

804-
/**
805-
* Default configurations that can be out of the box without creating an endpoint first.
806-
* @param defaultsListener Config listener
807-
*/
808806
@Override
807+
public void updateModelsWithDynamicFields(List<Model> models, ActionListener<List<Model>> listener) {
808+
var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
809+
for (var model : models) {
810+
if (model instanceof ElasticsearchInternalModel esModel) {
811+
modelsByDeploymentIds.put(esModel.internalServiceSettings.deloymentId(), esModel);
812+
} else {
813+
listener.onFailure(
814+
new ElasticsearchStatusException(
815+
"Cannot update model [{}] as it is not an Elasticsearch service model",
816+
RestStatus.INTERNAL_SERVER_ERROR,
817+
model.getInferenceEntityId()
818+
)
819+
);
820+
return;
821+
}
822+
}
823+
824+
if (modelsByDeploymentIds.isEmpty()) {
825+
listener.onResponse(models);
826+
return;
827+
}
828+
829+
String deploymentIds = String.join(",", modelsByDeploymentIds.keySet());
830+
client.execute(
831+
GetDeploymentStatsAction.INSTANCE,
832+
new GetDeploymentStatsAction.Request(deploymentIds),
833+
ActionListener.wrap(stats -> {
834+
for (var deploymentStats : stats.getStats().results()) {
835+
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
836+
model.updateNumAllocation(deploymentStats.getNumberOfAllocations());
837+
}
838+
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
839+
}, e -> {
840+
logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e);
841+
// continue with the original response
842+
listener.onResponse(models);
843+
})
844+
);
845+
}
846+
809847
public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
810848
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
811849
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ public String modelId() {
166166
return modelId;
167167
}
168168

169+
public String deloymentId() {
170+
return modelId;
171+
}
172+
169173
public Integer getNumAllocations() {
170174
return numAllocations;
171175
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elasticsearch;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.test.ESTestCase;
12+
13+
public class ElserInternalModelTests extends ESTestCase {
14+
public void testUpdateNumAllocation() {
15+
var model = new ElserInternalModel(
16+
"foo",
17+
TaskType.SPARSE_EMBEDDING,
18+
ElasticsearchInternalService.NAME,
19+
new ElserInternalServiceSettings(null, 1, "elser", null),
20+
new ElserMlNodeTaskSettings(),
21+
null
22+
);
23+
24+
model.updateNumAllocation(1);
25+
assertEquals(1, model.internalServiceSettings.getNumAllocations().intValue());
26+
27+
model.updateNumAllocation(null);
28+
assertNull(model.internalServiceSettings.getNumAllocations());
29+
}
30+
}

0 commit comments

Comments
 (0)