Skip to content

Commit c7f53ff

Browse files
authored
[ML] Dynamically get of num allocations for ml node models (elastic#115233)
The GET inference API which should dynamically update the num_allocations field with the actual number from the deployed model which is useful when adaptive allocations are used
1 parent f32051f commit c7f53ff

File tree

8 files changed

+219
-22
lines changed

8 files changed

+219
-22
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,8 @@ default List<DefaultConfigId> defaultConfigIds() {
210210
default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
211211
defaultsListener.onResponse(List.of());
212212
}
213+
214+
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
215+
listener.onResponse(model);
216+
}
213217
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,55 @@ public void testModelIdDoesNotMatch() throws IOException {
109109
);
110110
}
111111

112+
public void testNumAllocationsIsUpdated() throws IOException {
113+
var modelId = "update_num_allocations";
114+
var deploymentId = modelId;
115+
116+
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
117+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
118+
assertOkOrCreated(response);
119+
120+
var inferenceId = "test_num_allocations_updated";
121+
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
122+
var serviceSettings = putModel.get("service_settings");
123+
assertThat(
124+
putModel.toString(),
125+
serviceSettings,
126+
is(
127+
Map.of(
128+
"num_allocations",
129+
1,
130+
"num_threads",
131+
1,
132+
"model_id",
133+
"update_num_allocations",
134+
"deployment_id",
135+
"update_num_allocations"
136+
)
137+
)
138+
);
139+
140+
assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
141+
142+
var updatedServiceSettings = getModel(inferenceId).get("service_settings");
143+
assertThat(
144+
updatedServiceSettings.toString(),
145+
updatedServiceSettings,
146+
is(
147+
Map.of(
148+
"num_allocations",
149+
2,
150+
"num_threads",
151+
1,
152+
"model_id",
153+
"update_num_allocations",
154+
"deployment_id",
155+
"update_num_allocations"
156+
)
157+
)
158+
);
159+
}
160+
112161
private String endpointConfig(String deploymentId) {
113162
return Strings.format("""
114163
{
@@ -147,6 +196,20 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
147196
return client().performRequest(request);
148197
}
149198

199+
private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
200+
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";
201+
202+
var body = Strings.format("""
203+
{
204+
"number_of_allocations": %d
205+
}
206+
""", numAllocations);
207+
208+
Request request = new Request("POST", endPoint);
209+
request.setJsonEntity(body);
210+
return client().performRequest(request);
211+
}
212+
150213
protected void stopMlNodeDeployment(String deploymentId) throws IOException {
151214
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
152215
Request request = new Request("POST", endpoint);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.stream.Stream;
2525

2626
import static org.hamcrest.Matchers.containsString;
27+
import static org.hamcrest.Matchers.empty;
2728
import static org.hamcrest.Matchers.equalTo;
2829
import static org.hamcrest.Matchers.equalToIgnoringCase;
2930
import static org.hamcrest.Matchers.hasSize;
@@ -326,4 +327,9 @@ public void testSupportedStream() throws Exception {
326327
deleteModel(modelId);
327328
}
328329
}
330+
331+
public void testGetZeroModels() throws IOException {
332+
var models = getModels("_all", TaskType.RERANK);
333+
assertThat(models, empty());
334+
}
329335
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12-
import org.elasticsearch.action.ActionRunnable;
1312
import org.elasticsearch.action.support.ActionFilters;
13+
import org.elasticsearch.action.support.GroupedActionListener;
1414
import org.elasticsearch.action.support.HandledTransportAction;
1515
import org.elasticsearch.common.Strings;
1616
import org.elasticsearch.common.util.concurrent.EsExecutors;
1717
import org.elasticsearch.inference.InferenceServiceRegistry;
18-
import org.elasticsearch.inference.ModelConfigurations;
18+
import org.elasticsearch.inference.Model;
1919
import org.elasticsearch.inference.TaskType;
2020
import org.elasticsearch.inference.UnparsedModel;
2121
import org.elasticsearch.injection.guice.Inject;
@@ -29,8 +29,11 @@
2929
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3030

3131
import java.util.ArrayList;
32+
import java.util.Comparator;
33+
import java.util.HashMap;
3234
import java.util.List;
3335
import java.util.concurrent.Executor;
36+
import java.util.stream.Collectors;
3437

3538
public class TransportGetInferenceModelAction extends HandledTransportAction<
3639
GetInferenceModelAction.Request,
@@ -96,39 +99,77 @@ private void getSingleModel(
9699

97100
var model = service.get()
98101
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
99-
delegate.onResponse(new GetInferenceModelAction.Response(List.of(model.getConfigurations())));
102+
103+
service.get()
104+
.updateModelsWithDynamicFields(
105+
List.of(model),
106+
delegate.delegateFailureAndWrap(
107+
(l2, updatedModels) -> l2.onResponse(
108+
new GetInferenceModelAction.Response(
109+
updatedModels.stream().map(Model::getConfigurations).collect(Collectors.toList())
110+
)
111+
)
112+
)
113+
);
100114
}));
101115
}
102116

103117
private void getAllModels(boolean persistDefaultEndpoints, ActionListener<GetInferenceModelAction.Response> listener) {
104118
modelRegistry.getAllModels(
105119
persistDefaultEndpoints,
106-
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
120+
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))
107121
);
108122
}
109123

110124
private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {
111125
modelRegistry.getModelsByTaskType(
112126
taskType,
113-
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
127+
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))
114128
);
115129
}
116130

117-
private GetInferenceModelAction.Response parseModels(List<UnparsedModel> unparsedModels) {
118-
var parsedModels = new ArrayList<ModelConfigurations>();
131+
private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
132+
if (unparsedModels.isEmpty()) {
133+
listener.onResponse(new GetInferenceModelAction.Response(List.of()));
134+
return;
135+
}
119136

120-
for (var unparsedModel : unparsedModels) {
121-
var service = serviceRegistry.getService(unparsedModel.service());
122-
if (service.isEmpty()) {
123-
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
137+
var parsedModelsByService = new HashMap<String, List<Model>>();
138+
try {
139+
for (var unparsedModel : unparsedModels) {
140+
var service = serviceRegistry.getService(unparsedModel.service());
141+
if (service.isEmpty()) {
142+
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
143+
}
144+
var list = parsedModelsByService.computeIfAbsent(service.get().name(), s -> new ArrayList<>());
145+
list.add(
146+
service.get()
147+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
148+
);
124149
}
125-
parsedModels.add(
126-
service.get()
127-
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
128-
.getConfigurations()
150+
151+
var groupedListener = new GroupedActionListener<List<Model>>(
152+
parsedModelsByService.entrySet().size(),
153+
listener.delegateFailureAndWrap((delegate, listOfListOfModels) -> {
154+
var modifiable = new ArrayList<Model>();
155+
for (var l : listOfListOfModels) {
156+
modifiable.addAll(l);
157+
}
158+
modifiable.sort(Comparator.comparing(Model::getInferenceEntityId));
159+
delegate.onResponse(
160+
new GetInferenceModelAction.Response(modifiable.stream().map(Model::getConfigurations).collect(Collectors.toList()))
161+
);
162+
})
129163
);
164+
165+
for (var entry : parsedModelsByService.entrySet()) {
166+
serviceRegistry.getService(entry.getKey())
167+
.get() // must be non-null to get this far
168+
.updateModelsWithDynamicFields(entry.getValue(), groupedListener);
169+
}
170+
} catch (Exception e) {
171+
listener.onFailure(e);
130172
}
131-
return new GetInferenceModelAction.Response(parsedModels);
132173
}
133174

134175
private ElasticsearchStatusException serviceNotFoundException(String service, String inferenceId) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
public abstract class ElasticsearchInternalModel extends Model {
2323

24-
protected final ElasticsearchInternalServiceSettings internalServiceSettings;
24+
protected ElasticsearchInternalServiceSettings internalServiceSettings;
2525

2626
public ElasticsearchInternalModel(
2727
String inferenceEntityId,
@@ -91,6 +91,10 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
9191
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
9292
}
9393

94+
public void updateNumAllocations(Integer numAllocations) {
95+
this.internalServiceSettings.setNumAllocations(numAllocations);
96+
}
97+
9498
@Override
9599
public String toString() {
96100
return Strings.toString(this.getConfigurations());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
3333
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
3434
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
35+
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
3536
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
3637
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
3738
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
@@ -56,6 +57,7 @@
5657
import java.util.ArrayList;
5758
import java.util.Collections;
5859
import java.util.EnumSet;
60+
import java.util.HashMap;
5961
import java.util.List;
6062
import java.util.Map;
6163
import java.util.Optional;
@@ -786,11 +788,50 @@ public List<DefaultConfigId> defaultConfigIds() {
786788
);
787789
}
788790

789-
/**
790-
* Default configurations that can be out of the box without creating an endpoint first.
791-
* @param defaultsListener Config listener
792-
*/
793791
@Override
792+
public void updateModelsWithDynamicFields(List<Model> models, ActionListener<List<Model>> listener) {
793+
794+
if (models.isEmpty()) {
795+
listener.onResponse(models);
796+
return;
797+
}
798+
799+
var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
800+
for (var model : models) {
801+
assert model instanceof ElasticsearchInternalModel;
802+
803+
if (model instanceof ElasticsearchInternalModel esModel) {
804+
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
805+
} else {
806+
listener.onFailure(
807+
new ElasticsearchStatusException(
808+
"Cannot update model [{}] as it is not an Elasticsearch service model",
809+
RestStatus.INTERNAL_SERVER_ERROR,
810+
model.getInferenceEntityId()
811+
)
812+
);
813+
return;
814+
}
815+
}
816+
817+
String deploymentIds = String.join(",", modelsByDeploymentIds.keySet());
818+
client.execute(
819+
GetDeploymentStatsAction.INSTANCE,
820+
new GetDeploymentStatsAction.Request(deploymentIds),
821+
ActionListener.wrap(stats -> {
822+
for (var deploymentStats : stats.getStats().results()) {
823+
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
824+
model.updateNumAllocations(deploymentStats.getNumberOfAllocations());
825+
}
826+
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
827+
}, e -> {
828+
logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e);
829+
// continue with the original response
830+
listener.onResponse(models);
831+
})
832+
);
833+
}
834+
794835
public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
795836
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
796837
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
3939
public static final String DEPLOYMENT_ID = "deployment_id";
4040
public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations";
4141

42-
private final Integer numAllocations;
42+
private Integer numAllocations;
4343
private final int numThreads;
4444
private final String modelId;
4545
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
@@ -172,6 +172,10 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
172172
: null;
173173
}
174174

175+
public void setNumAllocations(Integer numAllocations) {
176+
this.numAllocations = numAllocations;
177+
}
178+
175179
@Override
176180
public void writeTo(StreamOutput out) throws IOException {
177181
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
@@ -194,6 +198,10 @@ public String modelId() {
194198
return modelId;
195199
}
196200

201+
public String deloymentId() {
202+
return modelId;
203+
}
204+
197205
public Integer getNumAllocations() {
198206
return numAllocations;
199207
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elasticsearch;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.test.ESTestCase;
12+
13+
public class ElserInternalModelTests extends ESTestCase {
14+
public void testUpdateNumAllocation() {
15+
var model = new ElserInternalModel(
16+
"foo",
17+
TaskType.SPARSE_EMBEDDING,
18+
ElasticsearchInternalService.NAME,
19+
new ElserInternalServiceSettings(null, 1, "elser", null),
20+
new ElserMlNodeTaskSettings(),
21+
null
22+
);
23+
24+
model.updateNumAllocations(1);
25+
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());
26+
27+
model.updateNumAllocations(null);
28+
assertNull(model.getServiceSettings().getNumAllocations());
29+
}
30+
}

0 commit comments

Comments
 (0)