Skip to content

Commit 12256ae

Browse files
committed
Remove queue
1 parent f307f05 commit 12256ae

File tree

8 files changed

+246
-167
lines changed

8 files changed

+246
-167
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AdaptiveAllocationsScaleFromZeroIT.java

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Arrays;
2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.concurrent.ConcurrentLinkedDeque;
2223
import java.util.concurrent.CountDownLatch;
2324
import java.util.concurrent.TimeUnit;
2425

@@ -36,7 +37,7 @@ public void setShortScaleToZeroPeriod() throws IOException {
3637
scaleToZeroTime.setJsonEntity("""
3738
{
3839
"persistent": {
39-
"xpack.ml.adaptive_allocations_scale_to_zero": "2s"
40+
"xpack.ml.adaptive_allocations_scale_to_zero_interval": "2s"
4041
}
4142
}""");
4243

@@ -51,13 +52,14 @@ public void testScaleFromZero() throws Exception {
5152
putVocabulary(List.of("Auto", "scale", "and", "infer"), modelId);
5253

5354
startDeployment(modelId, modelId, new AdaptiveAllocationsSettings(true, 0, 1));
54-
55-
var responseMap = entityAsMap(getTrainedModelStats(modelId));
56-
List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
57-
String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
58-
assertThat(responseMap.toString(), statusState, is(not(nullValue())));
59-
Integer count = (Integer) XContentMapValues.extractValue("deployment_stats.allocation_status.allocation_count", stats.get(0));
60-
assertThat(responseMap.toString(), count, is(1));
55+
{
56+
var responseMap = entityAsMap(getTrainedModelStats(modelId));
57+
List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
58+
String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
59+
assertThat(responseMap.toString(), statusState, is(not(nullValue())));
60+
Integer count = (Integer) XContentMapValues.extractValue("deployment_stats.allocation_status.allocation_count", stats.get(0));
61+
assertThat(responseMap.toString(), count, is(1));
62+
}
6163

6264
// wait for scale down. The scaler service will check every 10 seconds
6365
assertBusy(() -> {
@@ -70,8 +72,10 @@ public void testScaleFromZero() throws Exception {
7072
assertThat(statsMap.toString(), innerCount, is(0));
7173
}, 30, TimeUnit.SECONDS);
7274

75+
var failures = new ConcurrentLinkedDeque<Exception>();
76+
7377
// infer will scale up
74-
int inferenceCount = 100;
78+
int inferenceCount = 10;
7579
var latch = new CountDownLatch(inferenceCount);
7680
for (int i = 0; i < inferenceCount; i++) {
7781
asyncInfer("Auto scale and infer", modelId, TimeValue.timeValueSeconds(5), new ResponseListener() {
@@ -83,12 +87,24 @@ public void onSuccess(Response response) {
8387
@Override
8488
public void onFailure(Exception exception) {
8589
latch.countDown();
86-
fail(exception.getMessage());
90+
failures.add(exception);
8791
}
8892
});
8993
}
9094

9195
latch.await();
96+
if (failures.isEmpty() == false) {
97+
fail(failures.getFirst());
98+
}
99+
100+
// {
101+
// var responseMap = entityAsMap(getTrainedModelStats(modelId));
102+
// List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
103+
// String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
104+
// assertThat(responseMap.toString(), statusState, is(not(nullValue())));
105+
// Integer count = (Integer) XContentMapValues.extractValue("deployment_stats.allocation_status.allocation_count", stats.get(0));
106+
// assertThat(responseMap.toString(), count, greaterThan(0));
107+
// }
92108
}
93109

94110
@SuppressWarnings("unchecked")

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,8 @@ protected Response startDeployment(String modelId, String deploymentId, Adaptive
296296
+ "/deployment/_start"
297297
+ "?deployment_id="
298298
+ deploymentId
299-
+ "&threads_per_allocation=1";
299+
+ "&threads_per_allocation=1"
300+
+ "&wait_for=started";
300301

301302
ChunkedToXContentObject innerChunkedContent = params -> Iterators.concat(
302303
ChunkedToXContentHelper.startObject(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -770,8 +770,8 @@ public void loadExtensions(ExtensionLoader loader) {
770770
* The time interval without any requests that has to pass, before scaling down
771771
* to zero allocations.
772772
*/
773-
public static final Setting<TimeValue> ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_TIME = Setting.timeSetting(
774-
"xpack.ml.adaptive_allocations_scale_to_zero",
773+
public static final Setting<TimeValue> ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_INTERVAL = Setting.timeSetting(
774+
"xpack.ml.adaptive_allocations_scale_to_zero_interval",
775775
TimeValue.timeValueMinutes(15),
776776
TimeValue.timeValueSeconds(1),
777777
Property.Dynamic,
@@ -838,7 +838,7 @@ public List<Setting<?>> getSettings() {
838838
DELAYED_DATA_CHECK_FREQ,
839839
DUMMY_ENTITY_MEMORY,
840840
DUMMY_ENTITY_PROCESSORS,
841-
ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_TIME
841+
ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_INTERVAL
842842
);
843843
}
844844

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@
4343
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
4444
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
4545
import org.elasticsearch.xpack.ml.MachineLearning;
46+
import org.elasticsearch.xpack.ml.inference.InferenceWaitForAllocation;
4647
import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService;
4748
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
4849
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
4950
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
5051
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
51-
import org.elasticsearch.xpack.ml.inference.waitforallocations.ScalingInference;
5252
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
5353

5454
import java.util.Collections;
@@ -71,7 +71,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
7171
private final XPackLicenseState licenseState;
7272
private final TrainedModelProvider trainedModelProvider;
7373
private final AdaptiveAllocationsScalerService adaptiveAllocationsScalerService;
74-
private final ScalingInference scalingInference;
74+
private final InferenceWaitForAllocation scalingInference;
7575
private final ThreadPool threadPool;
7676

7777
TransportInternalInferModelAction(
@@ -94,7 +94,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
9494
this.licenseState = licenseState;
9595
this.trainedModelProvider = trainedModelProvider;
9696
this.adaptiveAllocationsScalerService = adaptiveAllocationsScalerService;
97-
this.scalingInference = new ScalingInference(assignmentService, this::inferOnBlockedRequest);
97+
this.scalingInference = new InferenceWaitForAllocation(assignmentService, this::inferOnBlockedRequest);
9898
this.threadPool = threadPool;
9999
}
100100

@@ -280,7 +280,9 @@ private void inferAgainstAllocatedModel(
280280
if (starting) {
281281
message += "; starting deployment of one allocation";
282282
logger.info(message);
283-
scalingInference.waitForAssignment(new ScalingInference.WaitingRequest(request, responseBuilder, parentTaskId, listener));
283+
scalingInference.waitForAssignment(
284+
new InferenceWaitForAllocation.WaitingRequest(request, responseBuilder, parentTaskId, listener)
285+
);
284286
return;
285287
}
286288

@@ -293,7 +295,7 @@ private void inferAgainstAllocatedModel(
293295
: "mismatch; sum of node requests does not match number of documents in request";
294296
}
295297

296-
private void inferOnBlockedRequest(ScalingInference.WaitingRequest request, TrainedModelAssignment assignment) {
298+
private void inferOnBlockedRequest(InferenceWaitForAllocation.WaitingRequest request, TrainedModelAssignment assignment) {
297299
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
298300

299301
var nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.inference;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchStatusException;
13+
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.ClusterState;
15+
import org.elasticsearch.rest.RestStatus;
16+
import org.elasticsearch.tasks.TaskId;
17+
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
18+
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
19+
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
20+
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
21+
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
22+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
23+
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
24+
25+
import java.util.HashMap;
26+
import java.util.Map;
27+
import java.util.concurrent.atomic.AtomicInteger;
28+
import java.util.concurrent.atomic.AtomicReference;
29+
import java.util.function.BiConsumer;
30+
import java.util.function.Predicate;
31+
32+
import static org.elasticsearch.core.Strings.format;
33+
34+
/**
35+
* Class for storing inference requests for ml trained models while
36+
* scaling is in progress. Once the trained model has at least 1
37+
* allocation the stored requests are forwarded to a consumer for
38+
* processing.Requests will timeout while waiting for scale.
39+
*/
40+
public class InferenceWaitForAllocation {
41+
42+
public static final int MAX_PENDING_REQUEST_COUNT = 100;
43+
44+
/**
45+
* Track details of the pending request
46+
*/
47+
public record WaitingRequest(
48+
InferModelAction.Request request,
49+
InferModelAction.Response.Builder responseBuilder,
50+
TaskId parentTaskId,
51+
ActionListener<InferModelAction.Response> listener
52+
) {
53+
public String deploymentId() {
54+
return request.getId();
55+
}
56+
}
57+
58+
private static final Logger logger = LogManager.getLogger(InferenceWaitForAllocation.class);
59+
60+
private final TrainedModelAssignmentService assignmentService;
61+
private final BiConsumer<WaitingRequest, TrainedModelAssignment> queuedConsumer;
62+
private AtomicInteger pendingRequestCount = new AtomicInteger();
63+
64+
/**
65+
* Create with consumer of the successful requests
66+
* @param assignmentService Trained model assignment service
67+
* @param onInferenceScaledConsumer The consumer of the waiting request called once an
68+
* allocation is available.
69+
*/
70+
public InferenceWaitForAllocation(
71+
TrainedModelAssignmentService assignmentService,
72+
BiConsumer<WaitingRequest, TrainedModelAssignment> onInferenceScaledConsumer
73+
) {
74+
this.assignmentService = assignmentService;
75+
this.queuedConsumer = onInferenceScaledConsumer;
76+
}
77+
78+
/**
79+
* Wait for at least 1 allocation to be started then process the
80+
* inference request.
81+
* If the pending request count is greater than {@link #MAX_PENDING_REQUEST_COUNT}
82+
* the request listener is failed with a too many requests exception
83+
* The timeout is the inference request timeout.
84+
* @param request The inference request details
85+
*/
86+
public synchronized void waitForAssignment(WaitingRequest request) {
87+
logger.info("waitForAssignment will wait for condition");
88+
if (pendingRequestCount.get() > MAX_PENDING_REQUEST_COUNT) {
89+
request.listener.onFailure(
90+
new ElasticsearchStatusException(
91+
"Rejected inference request waiting for an allocation of deployment [{}]. Too many pending requests",
92+
RestStatus.TOO_MANY_REQUESTS,
93+
request.request.getId()
94+
)
95+
);
96+
return;
97+
}
98+
99+
pendingRequestCount.incrementAndGet();
100+
var prediate = new DeploymentHasAtLeastOneAllocation(request.deploymentId());
101+
102+
assignmentService.waitForAssignmentCondition(
103+
request.deploymentId(),
104+
prediate,
105+
request.request().getInferenceTimeout(),
106+
new WaitingListener(request.deploymentId(), request, prediate)
107+
);
108+
}
109+
110+
private static class DeploymentHasAtLeastOneAllocation implements Predicate<ClusterState> {
111+
112+
private final String deploymentId;
113+
private AtomicReference<Exception> exception = new AtomicReference<>();
114+
115+
DeploymentHasAtLeastOneAllocation(String deploymentId) {
116+
this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, "deployment_id");
117+
}
118+
119+
@Override
120+
public boolean test(ClusterState clusterState) {
121+
logger.info("predicate test");
122+
TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentMetadata.assignmentForDeploymentId(
123+
clusterState,
124+
deploymentId
125+
).orElse(null);
126+
if (trainedModelAssignment == null) {
127+
logger.info(() -> format("[%s] assignment was null while waiting to scale up", deploymentId));
128+
return false;
129+
}
130+
131+
Map<String, String> nodeFailuresAndReasons = new HashMap<>();
132+
for (var nodeIdAndRouting : trainedModelAssignment.getNodeRoutingTable().entrySet()) {
133+
if (RoutingState.FAILED.equals(nodeIdAndRouting.getValue().getState())) {
134+
nodeFailuresAndReasons.put(nodeIdAndRouting.getKey(), nodeIdAndRouting.getValue().getReason());
135+
}
136+
}
137+
if (nodeFailuresAndReasons.isEmpty() == false) {
138+
if (nodeFailuresAndReasons.size() == trainedModelAssignment.getNodeRoutingTable().size()) {
139+
exception.set(
140+
new ElasticsearchStatusException(
141+
"[{}] Error waiting for a model allocation, all nodes have failed with errors [{}]",
142+
RestStatus.INTERNAL_SERVER_ERROR,
143+
trainedModelAssignment.getDeploymentId(),
144+
nodeFailuresAndReasons
145+
)
146+
);
147+
return true; // don't try again
148+
} else {
149+
logger.warn("Deployment [{}] has failed routes [{}]", trainedModelAssignment.getDeploymentId(), nodeFailuresAndReasons);
150+
}
151+
}
152+
153+
var routable = trainedModelAssignment.getNodeRoutingTable().values().stream().filter(RoutingInfo::isRoutable).findFirst();
154+
if (routable.isPresent()) {
155+
logger.info("first route " + routable.get() + ", state" + trainedModelAssignment.calculateAllocationStatus());
156+
} else {
157+
logger.info("no routes");
158+
}
159+
160+
return routable.isPresent();
161+
}
162+
}
163+
164+
private class WaitingListener implements TrainedModelAssignmentService.WaitForAssignmentListener {
165+
166+
private final String deploymentId;
167+
private final WaitingRequest request;
168+
private final DeploymentHasAtLeastOneAllocation predicate;
169+
170+
private WaitingListener(String deploymentId, WaitingRequest request, DeploymentHasAtLeastOneAllocation predicate) {
171+
this.deploymentId = deploymentId;
172+
this.request = request;
173+
this.predicate = predicate;
174+
}
175+
176+
@Override
177+
public void onResponse(TrainedModelAssignment assignment) {
178+
// assignment is started, do inference
179+
pendingRequestCount.decrementAndGet();
180+
181+
if (predicate.exception.get() != null) {
182+
onFailure(predicate.exception.get());
183+
return;
184+
}
185+
186+
logger.info("sending waited request");
187+
queuedConsumer.accept(request, assignment);
188+
}
189+
190+
@Override
191+
public void onFailure(Exception e) {
192+
logger.info("failed waiting", e);
193+
pendingRequestCount.decrementAndGet();
194+
request.listener().onFailure(e);
195+
}
196+
}
197+
}

0 commit comments

Comments
 (0)