Skip to content

Commit 67dad43

Browse files
Enable updating adaptive_allocations for ElasticsearchInternalService
1 parent 1dfb70e commit 67dad43

File tree

3 files changed

+208
-6
lines changed

3 files changed

+208
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
12+
13+
import java.io.IOException;
14+
import java.util.List;
15+
import java.util.Map;
16+
import java.util.Optional;
17+
18+
public class InferenceUpdateElasticsearchInternalServiceModelIT extends CustomElandModelIT {
19+
private final List<AdaptiveAllocationsSettings> ADAPTIVE_ALLOCATIONS_SETTINGS = List.of(
20+
new AdaptiveAllocationsSettings(randomBoolean(), null, null),
21+
new AdaptiveAllocationsSettings(null, randomIntBetween(1, 10), null),
22+
new AdaptiveAllocationsSettings(null, null, randomIntBetween(1, 10)),
23+
new AdaptiveAllocationsSettings(randomBoolean(), randomIntBetween(1, 10), randomIntBetween(11, 20))
24+
);
25+
26+
public void testUpdateNumThreads() throws IOException {
27+
testUpdateElasticsearchInternalServiceEndpoint(
28+
Optional.of(randomIntBetween(2, 10)),
29+
Optional.empty(),
30+
Optional.empty(),
31+
Optional.empty()
32+
);
33+
}
34+
35+
public void testUpdateAdaptiveAllocationsSettings() throws IOException {
36+
for (AdaptiveAllocationsSettings settings : ADAPTIVE_ALLOCATIONS_SETTINGS) {
37+
testUpdateElasticsearchInternalServiceEndpoint(
38+
Optional.empty(),
39+
Optional.ofNullable(settings.getEnabled()),
40+
Optional.ofNullable(settings.getMinNumberOfAllocations()),
41+
Optional.ofNullable(settings.getMaxNumberOfAllocations())
42+
);
43+
}
44+
}
45+
46+
public void testUpdateNumAllocationsAndAdaptiveAllocationsSettings() throws IOException {
47+
testUpdateElasticsearchInternalServiceEndpoint(
48+
Optional.of(randomIntBetween(2, 10)),
49+
Optional.of(randomBoolean()),
50+
Optional.of(randomIntBetween(1, 10)),
51+
Optional.of(randomIntBetween(11, 20))
52+
);
53+
}
54+
55+
private void testUpdateElasticsearchInternalServiceEndpoint(
56+
Optional<Integer> updatedNumAllocations,
57+
Optional<Boolean> updatedAdaptiveAllocationsEnabled,
58+
Optional<Integer> updatedMinNumberOfAllocations,
59+
Optional<Integer> updatedMaxNumberOfAllocations
60+
) throws IOException {
61+
var inferenceId = "update-adaptive-allocations-inference";
62+
var originalEndpoint = setupInferenceEndpoint(inferenceId);
63+
verifyEndpointConfig(originalEndpoint, 1, Optional.empty(), Optional.empty(), Optional.empty());
64+
65+
var updateConfig = generateUpdateConfig(
66+
updatedNumAllocations,
67+
updatedAdaptiveAllocationsEnabled,
68+
updatedMinNumberOfAllocations,
69+
updatedMaxNumberOfAllocations
70+
);
71+
var updatedEndpoint = updateEndpoint(inferenceId, updateConfig, TaskType.SPARSE_EMBEDDING);
72+
verifyEndpointConfig(
73+
updatedEndpoint,
74+
updatedNumAllocations.orElse(1),
75+
updatedAdaptiveAllocationsEnabled,
76+
updatedMinNumberOfAllocations,
77+
updatedMaxNumberOfAllocations
78+
);
79+
}
80+
81+
private Map<String, Object> setupInferenceEndpoint(String inferenceId) throws IOException {
82+
String modelId = "custom-text-expansion-model";
83+
createMlNodeTextExpansionModel(modelId, client());
84+
85+
var inferenceConfig = """
86+
{
87+
"service": "elasticsearch",
88+
"service_settings": {
89+
"model_id": "custom-text-expansion-model",
90+
"num_allocations": 1,
91+
"num_threads": 1
92+
}
93+
}
94+
""";
95+
96+
return putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
97+
}
98+
99+
public static String generateUpdateConfig(
100+
Optional<Integer> numAllocations,
101+
Optional<Boolean> adaptiveAllocationsEnabled,
102+
Optional<Integer> minNumberOfAllocations,
103+
Optional<Integer> maxNumberOfAllocations
104+
) {
105+
StringBuilder requestBodyBuilder = new StringBuilder();
106+
requestBodyBuilder.append("{ \"service_settings\": {");
107+
108+
numAllocations.ifPresent(value -> requestBodyBuilder.append("\"num_allocations\": ").append(value).append(","));
109+
110+
if (adaptiveAllocationsEnabled.isPresent() || minNumberOfAllocations.isPresent() || maxNumberOfAllocations.isPresent()) {
111+
requestBodyBuilder.append("\"adaptive_allocations\": {");
112+
adaptiveAllocationsEnabled.ifPresent(value -> requestBodyBuilder.append("\"enabled\": ").append(value).append(","));
113+
minNumberOfAllocations.ifPresent(
114+
value -> requestBodyBuilder.append("\"min_number_of_allocations\": ").append(value).append(",")
115+
);
116+
maxNumberOfAllocations.ifPresent(
117+
value -> requestBodyBuilder.append("\"max_number_of_allocations\": ").append(value).append(",")
118+
);
119+
120+
if (requestBodyBuilder.charAt(requestBodyBuilder.length() - 1) == ',') {
121+
requestBodyBuilder.deleteCharAt(requestBodyBuilder.length() - 1);
122+
}
123+
requestBodyBuilder.append("},");
124+
}
125+
126+
if (requestBodyBuilder.charAt(requestBodyBuilder.length() - 1) == ',') {
127+
requestBodyBuilder.deleteCharAt(requestBodyBuilder.length() - 1);
128+
}
129+
130+
requestBodyBuilder.append("} }");
131+
return requestBodyBuilder.toString();
132+
}
133+
134+
@SuppressWarnings("unchecked")
135+
private void verifyEndpointConfig(
136+
Map<String, Object> endpointConfig,
137+
int expectedNumAllocations,
138+
Optional<Boolean> adaptiveAllocationsEnabled,
139+
Optional<Integer> minNumberOfAllocations,
140+
Optional<Integer> maxNumberOfAllocations
141+
) {
142+
var serviceSettings = (Map<String, Object>) endpointConfig.get("service_settings");
143+
144+
assertEquals(expectedNumAllocations, serviceSettings.get("num_allocations"));
145+
if (adaptiveAllocationsEnabled.isPresent() || minNumberOfAllocations.isPresent() || maxNumberOfAllocations.isPresent()) {
146+
var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
147+
adaptiveAllocationsEnabled.ifPresent(enabled -> assertEquals(enabled, adaptiveAllocations.get("enabled")));
148+
minNumberOfAllocations.ifPresent(min -> assertEquals(min, adaptiveAllocations.get("min_number_of_allocations")));
149+
maxNumberOfAllocations.ifPresent(max -> assertEquals(max, adaptiveAllocations.get("max_number_of_allocations")));
150+
}
151+
}
152+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
4747
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
4848
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
49+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
4950
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
5051
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
5152
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -63,6 +64,7 @@
6364

6465
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
6566
import static org.elasticsearch.xpack.inference.services.ServiceUtils.resolveTaskType;
67+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS;
6668
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS;
6769

6870
public class TransportUpdateInferenceModelAction extends TransportMasterNodeAction<
@@ -220,12 +222,17 @@ private Model combineExistingModelWithNewSettings(
220222
if (settingsToUpdate.serviceSettings() != null && existingSecretSettings != null) {
221223
newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
222224
}
223-
if (settingsToUpdate.serviceSettings() != null && settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)) {
225+
if (settingsToUpdate.serviceSettings() != null
226+
&& (settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)
227+
|| settingsToUpdate.serviceSettings().containsKey(ADAPTIVE_ALLOCATIONS))) {
224228
// In cluster services can only have their num_allocations updated, so this is a special case
225229
if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
226230
newServiceSettings = new ElasticsearchInternalServiceSettings(
227231
elasticServiceSettings,
228-
(Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS)
232+
settingsToUpdate.serviceSettings().containsKey(NUM_ALLOCATIONS)
233+
? (Integer) settingsToUpdate.serviceSettings().get(NUM_ALLOCATIONS)
234+
: null,
235+
getAdaptiveAllocationsSettingsFromMap(settingsToUpdate.serviceSettings())
229236
);
230237
}
231238
}
@@ -259,10 +266,15 @@ private void updateInClusterEndpoint(
259266
throwIfTrainedModelDoesntExist(request.getInferenceEntityId(), deploymentId);
260267

261268
Map<String, Object> serviceSettings = request.getContentAsSettings().serviceSettings();
262-
if (serviceSettings != null && serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer numAllocations) {
269+
if (serviceSettings != null
270+
&& (serviceSettings.get(NUM_ALLOCATIONS) instanceof Integer || serviceSettings.containsKey(ADAPTIVE_ALLOCATIONS))) {
271+
var numAllocations = (Integer) serviceSettings.get(NUM_ALLOCATIONS);
272+
var adaptiveAllocationsSettings = getAdaptiveAllocationsSettingsFromMap(serviceSettings);
273+
// TODO: Figure out how to deep clonse the adaptive allocations settings as they are already removed at this point.
263274

264275
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
265276
updateRequest.setNumberOfAllocations(numAllocations);
277+
updateRequest.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);
266278

267279
var delegate = listener.<CreateTrainedModelAssignmentAction.Response>delegateFailure((l2, response) -> {
268280
modelRegistry.updateModelTransaction(newModel, existingParsedModel, l2);
@@ -339,6 +351,36 @@ private void checkEndpointExists(String inferenceEntityId, ActionListener<Unpars
339351
}));
340352
}
341353

354+
@SuppressWarnings("unchecked")
355+
private AdaptiveAllocationsSettings getAdaptiveAllocationsSettingsFromMap(Map<String, Object> settings) {
356+
if (settings == null || settings.isEmpty() || settings.containsKey(ADAPTIVE_ALLOCATIONS) == false) {
357+
return null;
358+
}
359+
360+
var adaptiveAllocationsSettingsMap = (Map<String, Object>) settings.get(ADAPTIVE_ALLOCATIONS);
361+
362+
// TODO: Test invalid type being passed here. Also test if updating causes any issues with the UI
363+
var adaptiveAllocationsSettingsBuilder = new AdaptiveAllocationsSettings.Builder();
364+
adaptiveAllocationsSettingsBuilder.setEnabled(
365+
(Boolean) adaptiveAllocationsSettingsMap.get(AdaptiveAllocationsSettings.ENABLED.getPreferredName())
366+
);
367+
adaptiveAllocationsSettingsBuilder.setMinNumberOfAllocations(
368+
(Integer) adaptiveAllocationsSettingsMap.get(AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS.getPreferredName())
369+
);
370+
adaptiveAllocationsSettingsBuilder.setMaxNumberOfAllocations(
371+
(Integer) adaptiveAllocationsSettingsMap.get(AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS.getPreferredName())
372+
);
373+
374+
var adaptiveAllocationsSettings = adaptiveAllocationsSettingsBuilder.build();
375+
var validationException = adaptiveAllocationsSettings.validate();
376+
377+
if (validationException != null) {
378+
throw validationException;
379+
}
380+
381+
return adaptiveAllocationsSettings;
382+
}
383+
342384
private static XContentParser getParser(UpdateInferenceModelAction.Request request) throws IOException {
343385
return XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType());
344386
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,20 @@ protected ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSetti
135135
* Copy constructor with the ability to set the number of allocations. Used for Update API.
136136
* @param other the existing settings
137137
* @param numAllocations the new number of allocations
138+
* @param adaptiveAllocationsSettings the new adaptive allocations settings
138139
*/
139-
public ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSettings other, int numAllocations) {
140-
this.numAllocations = numAllocations;
140+
public ElasticsearchInternalServiceSettings(
141+
ElasticsearchInternalServiceSettings other,
142+
Integer numAllocations,
143+
AdaptiveAllocationsSettings adaptiveAllocationsSettings
144+
) {
145+
this.numAllocations = numAllocations == null ? other.numAllocations : numAllocations;
146+
// TODO: Should we block numAllocations<minNumOfAllocations. Also does this get updated by adaptive allocations?
141147
this.numThreads = other.numThreads;
142148
this.modelId = other.modelId;
143-
this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings;
149+
this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings == null
150+
? adaptiveAllocationsSettings
151+
: other.adaptiveAllocationsSettings.merge(adaptiveAllocationsSettings);
144152
this.deploymentId = other.deploymentId;
145153
}
146154

0 commit comments

Comments
 (0)