Skip to content

Commit bd6eeca

Browse files
authored
[ML] Wait for allocation on scale up from 0 (elastic#114719)
1 parent 8240945 commit bd6eeca

File tree

16 files changed

+707
-77
lines changed

16 files changed

+707
-77
lines changed

docs/changelog/114719.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114719
2+
summary: Wait for allocation on scale up
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ public enum FeatureFlag {
2020
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
2121
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
2222
CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
23-
INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
23+
INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
24+
ML_SCALE_FROM_ZERO("es.ml_scale_from_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
2425

2526
public final String systemProperty;
2627
public final Version from;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,15 +224,12 @@ public boolean hasStartedRoutes() {
224224
return nodeRoutingTable.values().stream().anyMatch(routeInfo -> routeInfo.getState() == RoutingState.STARTED);
225225
}
226226

227-
public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocationsForNRequests(
228-
int numberOfRequests,
229-
RoutingState requiredState
230-
) {
227+
public List<Tuple<String, Integer>> selectRandomNodesWeighedOnAllocations(int numberOfRequests, RoutingState... acceptableStates) {
231228
List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
232229
List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());
233230
int allocationSum = 0;
234231
for (Map.Entry<String, RoutingInfo> routingEntry : nodeRoutingTable.entrySet()) {
235-
if (routingEntry.getValue().getState() == requiredState) {
232+
if (routingEntry.getValue().getState().isAnyOf(acceptableStates)) {
236233
nodeIds.add(routingEntry.getKey());
237234
allocationSum += routingEntry.getValue().getCurrentAllocations();
238235
cumulativeAllocations.add(allocationSum);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,15 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoS
195195
builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, ""));
196196
TrainedModelAssignment assignment = builder.build();
197197

198-
assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED).isEmpty(), is(true));
198+
assertThat(assignment.selectRandomNodesWeighedOnAllocations(1, RoutingState.STARTED).isEmpty(), is(true));
199199
}
200200

201201
public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() {
202202
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null);
203203
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, ""));
204204
TrainedModelAssignment assignment = builder.build();
205205

206-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED);
206+
var nodes = assignment.selectRandomNodesWeighedOnAllocations(1, RoutingState.STARTED);
207207

208208
assertThat(nodes, contains(new Tuple<>("node-1", 1)));
209209
}
@@ -213,7 +213,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh
213213
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, ""));
214214
TrainedModelAssignment assignment = builder.build();
215215

216-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING);
216+
var nodes = assignment.selectRandomNodesWeighedOnAllocations(1, RoutingState.STOPPING);
217217

218218
assertThat(nodes, empty());
219219
}
@@ -223,7 +223,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh
223223
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STOPPING, ""));
224224
TrainedModelAssignment assignment = builder.build();
225225

226-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING);
226+
var nodes = assignment.selectRandomNodesWeighedOnAllocations(1, RoutingState.STOPPING);
227227

228228
assertThat(nodes, contains(new Tuple<>("node-1", 1)));
229229
}
@@ -234,7 +234,7 @@ public void testSingleRequestWith2Nodes() {
234234
builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, ""));
235235
TrainedModelAssignment assignment = builder.build();
236236

237-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED);
237+
var nodes = assignment.selectRandomNodesWeighedOnAllocations(1, RoutingState.STARTED);
238238
assertThat(nodes, hasSize(1));
239239
assertEquals(nodes.get(0).v2(), Integer.valueOf(1));
240240
}
@@ -248,7 +248,7 @@ public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul
248248

249249
final int selectionCount = 10000;
250250
final CountAccumulator countsPerNodeAccumulator = new CountAccumulator();
251-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED);
251+
var nodes = assignment.selectRandomNodesWeighedOnAllocations(selectionCount, RoutingState.STARTED);
252252

253253
assertThat(nodes, hasSize(3));
254254
assertThat(nodes.stream().mapToInt(Tuple::v2).sum(), equalTo(selectionCount));
@@ -269,7 +269,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul
269269
builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, ""));
270270
TrainedModelAssignment assignment = builder.build();
271271
final int selectionCount = 1000;
272-
var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED);
272+
var nodeCounts = assignment.selectRandomNodesWeighedOnAllocations(selectionCount, RoutingState.STARTED);
273273
assertThat(nodeCounts, hasSize(3));
274274

275275
var selectedNodes = new HashSet<String>();
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.integration;
9+
10+
import org.apache.lucene.tests.util.LuceneTestCase;
11+
import org.elasticsearch.client.Response;
12+
import org.elasticsearch.client.ResponseListener;
13+
import org.elasticsearch.common.xcontent.support.XContentMapValues;
14+
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
16+
17+
import java.util.Arrays;
18+
import java.util.List;
19+
import java.util.Map;
20+
import java.util.concurrent.ConcurrentLinkedDeque;
21+
import java.util.concurrent.CountDownLatch;
22+
import java.util.concurrent.TimeUnit;
23+
24+
import static org.hamcrest.Matchers.empty;
25+
import static org.hamcrest.Matchers.hasSize;
26+
import static org.hamcrest.Matchers.is;
27+
import static org.hamcrest.Matchers.not;
28+
import static org.hamcrest.Matchers.nullValue;
29+
30+
@LuceneTestCase.AwaitsFix(bugUrl = "Cannot test without setting the scale to zero period to a small value")
31+
public class AdaptiveAllocationsScaleFromZeroIT extends PyTorchModelRestTestCase {
32+
33+
@SuppressWarnings("unchecked")
34+
public void testScaleFromZero() throws Exception {
35+
String modelId = "test_scale_from_zero";
36+
createPassThroughModel(modelId);
37+
putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
38+
putVocabulary(List.of("Auto", "scale", "and", "infer"), modelId);
39+
40+
startDeployment(modelId, modelId, new AdaptiveAllocationsSettings(true, 0, 1));
41+
{
42+
var responseMap = entityAsMap(getTrainedModelStats(modelId));
43+
List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
44+
String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
45+
assertThat(responseMap.toString(), statusState, is(not(nullValue())));
46+
Integer count = (Integer) XContentMapValues.extractValue("deployment_stats.allocation_status.allocation_count", stats.get(0));
47+
assertThat(responseMap.toString(), count, is(1));
48+
}
49+
50+
// wait for scale down. The scaler service will check every 10 seconds
51+
assertBusy(() -> {
52+
var statsMap = entityAsMap(getTrainedModelStats(modelId));
53+
List<Map<String, Object>> innerStats = (List<Map<String, Object>>) statsMap.get("trained_model_stats");
54+
Integer innerCount = (Integer) XContentMapValues.extractValue(
55+
"deployment_stats.allocation_status.allocation_count",
56+
innerStats.get(0)
57+
);
58+
assertThat(statsMap.toString(), innerCount, is(0));
59+
}, 30, TimeUnit.SECONDS);
60+
61+
var failures = new ConcurrentLinkedDeque<Exception>();
62+
63+
// infer will scale up
64+
int inferenceCount = 10;
65+
var latch = new CountDownLatch(inferenceCount);
66+
for (int i = 0; i < inferenceCount; i++) {
67+
asyncInfer("Auto scale and infer", modelId, TimeValue.timeValueSeconds(5), new ResponseListener() {
68+
@Override
69+
public void onSuccess(Response response) {
70+
latch.countDown();
71+
}
72+
73+
@Override
74+
public void onFailure(Exception exception) {
75+
latch.countDown();
76+
failures.add(exception);
77+
}
78+
});
79+
}
80+
81+
latch.await();
82+
assertThat(failures, empty());
83+
}
84+
85+
@SuppressWarnings("unchecked")
86+
public void testMultipleDeploymentsWaiting() throws Exception {
87+
String id1 = "test_scale_from_zero_dep_1";
88+
String id2 = "test_scale_from_zero_dep_2";
89+
String id3 = "test_scale_from_zero_dep_3";
90+
var idsList = Arrays.asList(id1, id2, id3);
91+
for (var modelId : idsList) {
92+
createPassThroughModel(modelId);
93+
putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
94+
putVocabulary(List.of("Auto", "scale", "and", "infer"), modelId);
95+
96+
startDeployment(modelId, modelId, new AdaptiveAllocationsSettings(true, 0, 1));
97+
}
98+
99+
// wait for scale down. The scaler service will check every 10 seconds
100+
assertBusy(() -> {
101+
var statsMap = entityAsMap(getTrainedModelStats("test_scale_from_zero_dep_*"));
102+
List<Map<String, Object>> innerStats = (List<Map<String, Object>>) statsMap.get("trained_model_stats");
103+
assertThat(innerStats, hasSize(3));
104+
for (int i = 0; i < 3; i++) {
105+
Integer innerCount = (Integer) XContentMapValues.extractValue(
106+
"deployment_stats.allocation_status.allocation_count",
107+
innerStats.get(i)
108+
);
109+
assertThat(statsMap.toString(), innerCount, is(0));
110+
}
111+
}, 30, TimeUnit.SECONDS);
112+
113+
// infer will scale up
114+
int inferenceCount = 10;
115+
var latch = new CountDownLatch(inferenceCount);
116+
for (int i = 0; i < inferenceCount; i++) {
117+
asyncInfer("Auto scale and infer", randomFrom(idsList), TimeValue.timeValueSeconds(5), new ResponseListener() {
118+
@Override
119+
public void onSuccess(Response response) {
120+
latch.countDown();
121+
}
122+
123+
@Override
124+
public void onFailure(Exception exception) {
125+
latch.countDown();
126+
fail(exception.getMessage());
127+
}
128+
});
129+
}
130+
131+
latch.await();
132+
}
133+
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010
import org.apache.http.util.EntityUtils;
1111
import org.elasticsearch.client.Request;
1212
import org.elasticsearch.client.Response;
13+
import org.elasticsearch.client.ResponseListener;
14+
import org.elasticsearch.common.Strings;
1315
import org.elasticsearch.common.settings.Settings;
1416
import org.elasticsearch.common.util.concurrent.ThreadContext;
1517
import org.elasticsearch.common.xcontent.support.XContentMapValues;
16-
import org.elasticsearch.core.Strings;
1718
import org.elasticsearch.core.TimeValue;
1819
import org.elasticsearch.test.SecuritySettingsSourceField;
1920
import org.elasticsearch.test.rest.ESRestTestCase;
21+
import org.elasticsearch.xcontent.XContentBuilder;
22+
import org.elasticsearch.xcontent.json.JsonXContent;
23+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
2024
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
2125
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
2226
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
@@ -282,6 +286,27 @@ protected Response startDeployment(
282286
return client().performRequest(request);
283287
}
284288

289+
protected Response startDeployment(String modelId, String deploymentId, AdaptiveAllocationsSettings adaptiveAllocationsSettings)
290+
throws IOException {
291+
String endPoint = "/_ml/trained_models/"
292+
+ modelId
293+
+ "/deployment/_start"
294+
+ "?deployment_id="
295+
+ deploymentId
296+
+ "&threads_per_allocation=1"
297+
+ "&wait_for=started";
298+
299+
XContentBuilder builder = JsonXContent.contentBuilder();
300+
builder.startObject();
301+
builder.field("adaptive_allocations", adaptiveAllocationsSettings);
302+
builder.endObject();
303+
var body = Strings.toString(builder);
304+
305+
Request request = new Request("POST", endPoint);
306+
request.setJsonEntity(body);
307+
return client().performRequest(request);
308+
}
309+
285310
protected void stopDeployment(String modelId) throws IOException {
286311
stopDeployment(modelId, false, false);
287312
}
@@ -325,6 +350,14 @@ protected Response infer(String input, String modelId, TimeValue timeout) throws
325350
return client().performRequest(request);
326351
}
327352

353+
protected void asyncInfer(String input, String modelId, TimeValue timeout, ResponseListener responseListener) throws IOException {
354+
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=" + timeout.toString());
355+
request.setJsonEntity(Strings.format("""
356+
{ "docs": [{"input":"%s"}] }
357+
""", input));
358+
client().performRequestAsync(request, responseListener);
359+
}
360+
328361
protected Response infer(String input, String modelId) throws IOException {
329362
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
330363
request.setJsonEntity(Strings.format("""

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExternalInferModelAction.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
import org.elasticsearch.cluster.service.ClusterService;
1212
import org.elasticsearch.injection.guice.Inject;
1313
import org.elasticsearch.license.XPackLicenseState;
14+
import org.elasticsearch.threadpool.ThreadPool;
1415
import org.elasticsearch.transport.TransportService;
1516
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
1617
import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService;
18+
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
1719
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
1820
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
1921

@@ -27,7 +29,9 @@ public TransportExternalInferModelAction(
2729
ClusterService clusterService,
2830
XPackLicenseState licenseState,
2931
TrainedModelProvider trainedModelProvider,
30-
AdaptiveAllocationsScalerService adaptiveAllocationsScalerService
32+
AdaptiveAllocationsScalerService adaptiveAllocationsScalerService,
33+
TrainedModelAssignmentService assignmentService,
34+
ThreadPool threadPool
3135
) {
3236
super(
3337
InferModelAction.EXTERNAL_NAME,
@@ -38,7 +42,9 @@ public TransportExternalInferModelAction(
3842
clusterService,
3943
licenseState,
4044
trainedModelProvider,
41-
adaptiveAllocationsScalerService
45+
adaptiveAllocationsScalerService,
46+
assignmentService,
47+
threadPool
4248
);
4349
}
4450
}

0 commit comments

Comments
 (0)