Skip to content

Commit 3413592

Browse files
committed
tests
1 parent 0089850 commit 3413592

File tree

14 files changed

+312
-56
lines changed

14 files changed

+312
-56
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,15 +224,15 @@ public boolean hasStartedRoutes() {
224224
return nodeRoutingTable.values().stream().anyMatch(routeInfo -> routeInfo.getState() == RoutingState.STARTED);
225225
}
226226

227-
public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocationsForNRequests(
227+
public List<Tuple<String, Integer>> selectRandomNodesWeighedOnAllocationsForNRequestsAndState(
228228
int numberOfRequests,
229-
RoutingState requiredState
229+
RoutingState ... acceptableStates
230230
) {
231231
List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
232232
List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());
233233
int allocationSum = 0;
234234
for (Map.Entry<String, RoutingInfo> routingEntry : nodeRoutingTable.entrySet()) {
235-
if (routingEntry.getValue().getState() == requiredState) {
235+
if (routingEntry.getValue().getState().isAnyOf(acceptableStates)) {
236236
nodeIds.add(routingEntry.getKey());
237237
allocationSum += routingEntry.getValue().getCurrentAllocations();
238238
cumulativeAllocations.add(allocationSum);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,15 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoS
195195
builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, ""));
196196
TrainedModelAssignment assignment = builder.build();
197197

198-
assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED).isEmpty(), is(true));
198+
assertThat(assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(1, RoutingState.STARTED).isEmpty(), is(true));
199199
}
200200

201201
/**
 * Checks node selection when the routing table holds exactly one entry, "node-1", in the
 * STARTED state: asking for a single request restricted to STARTED routes must yield
 * exactly one tuple pairing "node-1" with a request count of 1.
 */
public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() {
202202
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null);
203203
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, ""));
204204
TrainedModelAssignment assignment = builder.build();
205205

206-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED);
206+
var nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(1, RoutingState.STARTED);
207207

208208
assertThat(nodes, contains(new Tuple<>("node-1", 1)));
209209
}
@@ -213,7 +213,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh
213213
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, ""));
214214
TrainedModelAssignment assignment = builder.build();
215215

216-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING);
216+
var nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(1, RoutingState.STOPPING);
217217

218218
assertThat(nodes, empty());
219219
}
@@ -223,7 +223,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh
223223
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STOPPING, ""));
224224
TrainedModelAssignment assignment = builder.build();
225225

226-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING);
226+
var nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(1, RoutingState.STOPPING);
227227

228228
assertThat(nodes, contains(new Tuple<>("node-1", 1)));
229229
}
@@ -234,7 +234,7 @@ public void testSingleRequestWith2Nodes() {
234234
builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, ""));
235235
TrainedModelAssignment assignment = builder.build();
236236

237-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED);
237+
var nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(1, RoutingState.STARTED);
238238
assertThat(nodes, hasSize(1));
239239
assertEquals(nodes.get(0).v2(), Integer.valueOf(1));
240240
}
@@ -248,7 +248,7 @@ public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul
248248

249249
final int selectionCount = 10000;
250250
final CountAccumulator countsPerNodeAccumulator = new CountAccumulator();
251-
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED);
251+
var nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(selectionCount, RoutingState.STARTED);
252252

253253
assertThat(nodes, hasSize(3));
254254
assertThat(nodes.stream().mapToInt(Tuple::v2).sum(), equalTo(selectionCount));
@@ -269,7 +269,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul
269269
builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, ""));
270270
TrainedModelAssignment assignment = builder.build();
271271
final int selectionCount = 1000;
272-
var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED);
272+
var nodeCounts = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(selectionCount, RoutingState.STARTED);
273273
assertThat(nodeCounts, hasSize(3));
274274

275275
var selectedNodes = new HashSet<String>();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.integration;
9+
10+
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.Response;
12+
import org.elasticsearch.client.ResponseListener;
13+
import org.elasticsearch.common.xcontent.support.XContentMapValues;
14+
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
16+
import org.junit.Before;
17+
18+
import java.io.IOException;
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.concurrent.CountDownLatch;
22+
import java.util.concurrent.TimeUnit;
23+
24+
import static org.hamcrest.Matchers.is;
25+
import static org.hamcrest.Matchers.not;
26+
import static org.hamcrest.Matchers.nullValue;
27+
28+
public class AdaptiveAllocationsScaleFromZeroIT extends PyTorchModelRestTestCase {
29+
30+
/**
 * Runs before each test and sets the persistent cluster setting
 * {@code xpack.ml.adaptive_allocations_scale_to_zero} to {@code 2s} through a
 * {@code PUT _cluster/settings} request, so idle deployments become eligible for
 * scaling to zero quickly.
 *
 * @throws IOException propagated from the REST client when the settings request fails
 */
@Before
31+
public void setShortScaleToZeroPeriod() throws IOException {
32+
logger.info("setting time");
33+
Request scaleToZeroTime = new Request("PUT", "_cluster/settings");
34+
scaleToZeroTime.setJsonEntity("""
35+
{
36+
"persistent": {
37+
"xpack.ml.adaptive_allocations_scale_to_zero": "2s"
38+
}
39+
}""");
40+
41+
client().performRequest(scaleToZeroTime);
42+
}
43+
44+
@SuppressWarnings("unchecked")
45+
public void testScaleFromZero() throws Exception {
46+
String modelId = "test_scale_from_zero";
47+
createPassThroughModel(modelId);
48+
putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
49+
putVocabulary(List.of("Auto", "scale", "and", "infer"), modelId);
50+
51+
startDeployment(modelId, modelId, new AdaptiveAllocationsSettings(true, 0, 1));
52+
53+
var responseMap = entityAsMap(getTrainedModelStats(modelId));
54+
List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
55+
String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
56+
assertThat(responseMap.toString(), statusState, is(not(nullValue())));
57+
Integer count = (Integer) XContentMapValues.extractValue("deployment_stats.allocation_status.allocation_count", stats.get(0));
58+
assertThat(responseMap.toString(), count, is(1));
59+
60+
// wait for scale down. The scaler service will check every 10 seconds
61+
assertBusy(() -> {
62+
var statsMap = entityAsMap(getTrainedModelStats(modelId));
63+
List<Map<String, Object>> innerStats = (List<Map<String, Object>>) statsMap.get("trained_model_stats");
64+
Integer innerCount = (Integer) XContentMapValues.extractValue(
65+
"deployment_stats.allocation_status.allocation_count",
66+
innerStats.get(0)
67+
);
68+
assertThat(statsMap.toString(), innerCount, is(0));
69+
}, 30, TimeUnit.SECONDS);
70+
71+
// infer will scale up
72+
int inferenceCount = 10;
73+
var latch = new CountDownLatch(inferenceCount);
74+
for (int i = 0; i < inferenceCount; i++) {
75+
asyncInfer("Auto scale and infer", modelId, TimeValue.timeValueSeconds(5), new ResponseListener() {
76+
@Override
77+
public void onSuccess(Response response) {
78+
latch.countDown();
79+
}
80+
81+
@Override
82+
public void onFailure(Exception exception) {
83+
latch.countDown();
84+
fail(exception.getMessage());
85+
}
86+
});
87+
}
88+
89+
latch.await();
90+
}
91+
92+
// public void testMultipleDeploymentsWaiting() {
93+
//
94+
// }
95+
96+
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,20 @@
1010
import org.apache.http.util.EntityUtils;
1111
import org.elasticsearch.client.Request;
1212
import org.elasticsearch.client.Response;
13+
import org.elasticsearch.client.ResponseListener;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.common.collect.Iterators;
1316
import org.elasticsearch.common.settings.Settings;
1417
import org.elasticsearch.common.util.concurrent.ThreadContext;
18+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
19+
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
1520
import org.elasticsearch.common.xcontent.support.XContentMapValues;
16-
import org.elasticsearch.core.Strings;
1721
import org.elasticsearch.core.TimeValue;
1822
import org.elasticsearch.test.SecuritySettingsSourceField;
1923
import org.elasticsearch.test.rest.ESRestTestCase;
24+
import org.elasticsearch.xcontent.XContentBuilder;
25+
import org.elasticsearch.xcontent.json.JsonXContent;
26+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
2027
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
2128
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
2229
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
@@ -282,6 +289,32 @@ protected Response startDeployment(
282289
return client().performRequest(request);
283290
}
284291

292+
protected Response startDeployment(String modelId, String deploymentId, AdaptiveAllocationsSettings adaptiveAllocationsSettings)
293+
throws IOException {
294+
String endPoint = "/_ml/trained_models/"
295+
+ modelId
296+
+ "/deployment/_start"
297+
+ "?deployment_id="
298+
+ deploymentId
299+
+ "&threads_per_allocation=1";
300+
301+
ChunkedToXContentObject innerChunkedContent = params -> Iterators.concat(
302+
ChunkedToXContentHelper.startObject(),
303+
Iterators.single(((builder, p2) -> builder.field("adaptive_allocations", adaptiveAllocationsSettings))),
304+
ChunkedToXContentHelper.endObject()
305+
);
306+
307+
XContentBuilder builder = JsonXContent.contentBuilder();
308+
builder.startObject();
309+
builder.field("adaptive_allocations", adaptiveAllocationsSettings);
310+
builder.endObject();
311+
var body = Strings.toString(builder);
312+
313+
Request request = new Request("POST", endPoint);
314+
request.setJsonEntity(body);
315+
return client().performRequest(request);
316+
}
317+
285318
/**
 * Convenience overload that stops the deployment for {@code modelId} by delegating to the
 * three-argument {@code stopDeployment}, passing {@code false} for both of its flags.
 *
 * @param modelId id passed through to the three-argument overload
 * @throws IOException propagated from the delegated call
 */
protected void stopDeployment(String modelId) throws IOException {
286319
stopDeployment(modelId, false, false);
287320
}
@@ -325,6 +358,14 @@ protected Response infer(String input, String modelId, TimeValue timeout) throws
325358
return client().performRequest(request);
326359
}
327360

361+
/**
 * Sends a single-document inference request to {@code /_ml/trained_models/{modelId}/_infer}
 * without waiting for the result; the outcome is reported to {@code responseListener}.
 * <p>
 * {@code input} is substituted into the JSON body as-is (no escaping is applied here), so
 * callers must pass text that is valid inside a JSON string literal.
 *
 * @param input            text placed in the {@code input} field of the single document
 * @param modelId          id used in the request path
 * @param timeout          value appended as the {@code timeout} query parameter
 * @param responseListener callback handed to the asynchronous client call
 * @throws IOException declared by the signature
 */
protected void asyncInfer(String input, String modelId, TimeValue timeout, ResponseListener responseListener) throws IOException {
362+
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=" + timeout.toString());
363+
request.setJsonEntity(Strings.format("""
364+
{ "docs": [{"input":"%s"}] }
365+
""", input));
366+
client().performRequestAsync(request, responseListener);
367+
}
368+
328369
protected Response infer(String input, String modelId) throws IOException {
329370
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
330371
request.setJsonEntity(Strings.format("""

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,18 @@ public void loadExtensions(ExtensionLoader loader) {
758758
*/
759759
public static final int MAX_LOW_PRIORITY_MODELS_PER_NODE = 100;
760760

761+
/**
762+
* The time interval without any requests that has to pass, before scaling down
763+
* to zero allocations.
764+
*/
765+
public static final Setting<TimeValue> ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_TIME = Setting.timeSetting(
766+
"xpack.ml.adaptive_allocations_scale_to_zero",
767+
TimeValue.timeValueMinutes(15),
768+
TimeValue.timeValueSeconds(1),
769+
Property.Dynamic,
770+
Setting.Property.NodeScope
771+
);
772+
761773
private static final Logger logger = LogManager.getLogger(MachineLearning.class);
762774
private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(MachineLearning.class);
763775

@@ -817,7 +829,8 @@ public List<Setting<?>> getSettings() {
817829
MAX_ML_NODE_SIZE,
818830
DELAYED_DATA_CHECK_FREQ,
819831
DUMMY_ENTITY_MEMORY,
820-
DUMMY_ENTITY_PROCESSORS
832+
DUMMY_ENTITY_PROCESSORS,
833+
ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_TIME
821834
);
822835
}
823836

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExternalInferModelAction.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.cluster.service.ClusterService;
1212
import org.elasticsearch.injection.guice.Inject;
1313
import org.elasticsearch.license.XPackLicenseState;
14+
import org.elasticsearch.threadpool.ThreadPool;
1415
import org.elasticsearch.transport.TransportService;
1516
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
1617
import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService;
@@ -29,7 +30,8 @@ public TransportExternalInferModelAction(
2930
XPackLicenseState licenseState,
3031
TrainedModelProvider trainedModelProvider,
3132
AdaptiveAllocationsScalerService adaptiveAllocationsScalerService,
32-
TrainedModelAssignmentService assignmentService
33+
TrainedModelAssignmentService assignmentService,
34+
ThreadPool threadPool
3335
) {
3436
super(
3537
InferModelAction.EXTERNAL_NAME,
@@ -41,7 +43,8 @@ public TransportExternalInferModelAction(
4143
licenseState,
4244
trainedModelProvider,
4345
adaptiveAllocationsScalerService,
44-
assignmentService
46+
assignmentService,
47+
threadPool
4548
);
4649
}
4750
}

0 commit comments

Comments
 (0)