Skip to content

Commit f307f05

Browse files
committed
protect against multiple requests
1 parent 8e17a9a commit f307f05

File tree

7 files changed

+220
-51
lines changed

7 files changed

+220
-51
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ public boolean hasStartedRoutes() {
226226

227227
public List<Tuple<String, Integer>> selectRandomNodesWeighedOnAllocationsForNRequestsAndState(
228228
int numberOfRequests,
229-
RoutingState ... acceptableStates
229+
RoutingState... acceptableStates
230230
) {
231231
List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
232232
List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AdaptiveAllocationsScaleFromZeroIT.java

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import org.junit.Before;
1717

1818
import java.io.IOException;
19+
import java.util.Arrays;
1920
import java.util.List;
2021
import java.util.Map;
2122
import java.util.concurrent.CountDownLatch;
2223
import java.util.concurrent.TimeUnit;
2324

25+
import static org.hamcrest.Matchers.hasSize;
2426
import static org.hamcrest.Matchers.is;
2527
import static org.hamcrest.Matchers.not;
2628
import static org.hamcrest.Matchers.nullValue;
@@ -69,7 +71,7 @@ public void testScaleFromZero() throws Exception {
6971
}, 30, TimeUnit.SECONDS);
7072

7173
// infer will scale up
72-
int inferenceCount = 10;
74+
int inferenceCount = 100;
7375
var latch = new CountDownLatch(inferenceCount);
7476
for (int i = 0; i < inferenceCount; i++) {
7577
asyncInfer("Auto scale and infer", modelId, TimeValue.timeValueSeconds(5), new ResponseListener() {
@@ -89,8 +91,52 @@ public void onFailure(Exception exception) {
8991
latch.await();
9092
}
9193

92-
// public void testMultipleDeploymentsWaiting() {
93-
//
94-
// }
94+
@SuppressWarnings("unchecked")
95+
public void testMultipleDeploymentsWaiting() throws Exception {
96+
String id1 = "test_scale_from_zero_dep_1";
97+
String id2 = "test_scale_from_zero_dep_2";
98+
String id3 = "test_scale_from_zero_dep_3";
99+
var idsList = Arrays.asList(id1, id2, id3);
100+
for (var modelId : idsList) {
101+
createPassThroughModel(modelId);
102+
putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
103+
putVocabulary(List.of("Auto", "scale", "and", "infer"), modelId);
104+
105+
startDeployment(modelId, modelId, new AdaptiveAllocationsSettings(true, 0, 1));
106+
}
107+
108+
// wait for scale down. The scaler service will check every 10 seconds
109+
assertBusy(() -> {
110+
var statsMap = entityAsMap(getTrainedModelStats("test_scale_from_zero_dep_*"));
111+
List<Map<String, Object>> innerStats = (List<Map<String, Object>>) statsMap.get("trained_model_stats");
112+
assertThat(innerStats, hasSize(3));
113+
for (int i = 0; i < 3; i++) {
114+
Integer innerCount = (Integer) XContentMapValues.extractValue(
115+
"deployment_stats.allocation_status.allocation_count",
116+
innerStats.get(i)
117+
);
118+
assertThat(statsMap.toString(), innerCount, is(0));
119+
}
120+
}, 30, TimeUnit.SECONDS);
95121

122+
// infer will scale up
123+
int inferenceCount = 100;
124+
var latch = new CountDownLatch(inferenceCount);
125+
for (int i = 0; i < inferenceCount; i++) {
126+
asyncInfer("Auto scale and infer", randomFrom(idsList), TimeValue.timeValueSeconds(5), new ResponseListener() {
127+
@Override
128+
public void onSuccess(Response response) {
129+
latch.countDown();
130+
}
131+
132+
@Override
133+
public void onFailure(Exception exception) {
134+
latch.countDown();
135+
fail(exception.getMessage());
136+
}
137+
});
138+
}
139+
140+
latch.await();
141+
}
96142
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,10 @@ private void inferAgainstAllocatedModel(
268268

269269
// We couldn't find any nodes in the started state so let's look for ones that are stopping in case we're shutting down some nodes
270270
if (nodes.isEmpty()) {
271-
nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(request.numberOfDocuments(), RoutingState.STOPPING);
271+
nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(
272+
request.numberOfDocuments(),
273+
RoutingState.STOPPING
274+
);
272275
}
273276

274277
if (nodes.isEmpty()) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.cluster.service.ClusterService;
1312
import org.elasticsearch.common.Strings;
14-
import org.elasticsearch.core.TimeValue;
15-
16-
import static org.elasticsearch.xpack.ml.MachineLearning.ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_TIME;
1713

1814
/**
1915
* Processes measured requests counts and inference times and decides whether

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.threadpool.Scheduler;
2626
import org.elasticsearch.threadpool.ThreadPool;
2727
import org.elasticsearch.xpack.core.ClientHelper;
28+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
2829
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
2930
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
3031
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
@@ -40,6 +41,7 @@
4041
import java.util.List;
4142
import java.util.Map;
4243
import java.util.Set;
44+
import java.util.concurrent.ConcurrentSkipListSet;
4345
import java.util.concurrent.atomic.AtomicBoolean;
4446
import java.util.concurrent.atomic.AtomicLong;
4547
import java.util.function.Function;
@@ -217,6 +219,8 @@ void close() {
217219

218220
private final AtomicLong scaleToZeroAfterNoRequestsSeconds = new AtomicLong();
219221

222+
private final Set<String> inFlightScaleFromZeroRequests = new ConcurrentSkipListSet<>();
223+
220224
public AdaptiveAllocationsScalerService(
221225
ThreadPool threadPool,
222226
ClusterService clusterService,
@@ -298,8 +302,11 @@ private synchronized void updateAutoscalers(ClusterState state) {
298302
&& assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE) {
299303
AdaptiveAllocationsScaler adaptiveAllocationsScaler = scalers.computeIfAbsent(
300304
assignment.getDeploymentId(),
301-
key -> new AdaptiveAllocationsScaler(assignment.getDeploymentId(), assignment.totalTargetAllocations(),
302-
scaleToZeroAfterNoRequestsSeconds.get())
305+
key -> new AdaptiveAllocationsScaler(
306+
assignment.getDeploymentId(),
307+
assignment.totalTargetAllocations(),
308+
scaleToZeroAfterNoRequestsSeconds.get()
309+
)
303310
);
304311
adaptiveAllocationsScaler.setMinMaxNumberOfAllocations(
305312
assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations(),
@@ -427,22 +434,42 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
427434
if (newNumberOfAllocations > numberOfAllocations.get(deploymentId)) {
428435
lastScaleUpTimesMillis.put(deploymentId, now);
429436
}
430-
updateNumberOfAllocations(deploymentId, newNumberOfAllocations);
437+
updateNumberOfAllocations(
438+
deploymentId,
439+
newNumberOfAllocations,
440+
updateAssigmentListener(deploymentId, newNumberOfAllocations)
441+
);
431442
}
432443
}
433444
}
434445

435446
public boolean maybeStartAllocation(TrainedModelAssignment assignment) {
436447
if (assignment.getAdaptiveAllocationsSettings() != null
437-
&& assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE) {
438-
lastScaleUpTimesMillis.put(assignment.getDeploymentId(), System.currentTimeMillis());
439-
updateNumberOfAllocations(assignment.getDeploymentId(), 1);
448+
&& assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE
449+
&& assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations() == 0) {
450+
451+
// Protect against a flurry of scale-up requests.
452+
if (inFlightScaleFromZeroRequests.contains(assignment.getDeploymentId()) == false) {
453+
lastScaleUpTimesMillis.put(assignment.getDeploymentId(), System.currentTimeMillis());
454+
var updateListener = updateAssigmentListener(assignment.getDeploymentId(), 1);
455+
var cleanUpListener = ActionListener.runAfter(
456+
updateListener,
457+
() -> inFlightScaleFromZeroRequests.remove(assignment.getDeploymentId())
458+
);
459+
460+
inFlightScaleFromZeroRequests.add(assignment.getDeploymentId());
461+
updateNumberOfAllocations(assignment.getDeploymentId(), 1, cleanUpListener);
462+
}
440463
return true;
441464
}
442465
return false;
443466
}
444467

445-
private void updateNumberOfAllocations(String deploymentId, int numberOfAllocations) {
468+
private void updateNumberOfAllocations(
469+
String deploymentId,
470+
int numberOfAllocations,
471+
ActionListener<CreateTrainedModelAssignmentAction.Response> listener
472+
) {
446473
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
447474
updateRequest.setNumberOfAllocations(numberOfAllocations);
448475
updateRequest.setIsInternal(true);
@@ -451,40 +478,43 @@ private void updateNumberOfAllocations(String deploymentId, int numberOfAllocati
451478
ClientHelper.ML_ORIGIN,
452479
UpdateTrainedModelDeploymentAction.INSTANCE,
453480
updateRequest,
454-
ActionListener.wrap(updateResponse -> {
455-
logger.info("adaptive allocations scaler: scaled [{}] to [{}] allocations.", deploymentId, numberOfAllocations);
456-
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
457-
.execute(
458-
() -> inferenceAuditor.info(
459-
deploymentId,
460-
Strings.format(
461-
"adaptive allocations scaler: scaled [%s] to [%s] allocations.",
462-
deploymentId,
463-
numberOfAllocations
464-
)
465-
)
466-
);
467-
}, e -> {
468-
logger.atLevel(Level.WARN)
469-
.withThrowable(e)
470-
.log("adaptive allocations scaler: scaling [{}] to [{}] allocations failed.", deploymentId, numberOfAllocations);
471-
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
472-
.execute(
473-
() -> inferenceAuditor.warning(
474-
deploymentId,
475-
Strings.format(
476-
"adaptive allocations scaler: scaling [%s] to [%s] allocations failed.",
477-
deploymentId,
478-
numberOfAllocations
479-
)
480-
)
481-
);
482-
})
481+
listener
483482
);
484483
}
485484

486485
private void setScaleToZeroPeriod(TimeValue timeValue) {
487486
logger.info("setting scaler service to zero " + timeValue);
488487
scaleToZeroAfterNoRequestsSeconds.set(timeValue.seconds());
489488
}
489+
490+
private ActionListener<CreateTrainedModelAssignmentAction.Response> updateAssigmentListener(
491+
String deploymentId,
492+
int numberOfAllocations
493+
) {
494+
return ActionListener.wrap(updateResponse -> {
495+
logger.info("adaptive allocations scaler: scaled [{}] to [{}] allocations.", deploymentId, numberOfAllocations);
496+
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
497+
.execute(
498+
() -> inferenceAuditor.info(
499+
deploymentId,
500+
Strings.format("adaptive allocations scaler: scaled [%s] to [%s] allocations.", deploymentId, numberOfAllocations)
501+
)
502+
);
503+
}, e -> {
504+
logger.atLevel(Level.WARN)
505+
.withThrowable(e)
506+
.log("adaptive allocations scaler: scaling [{}] to [{}] allocations failed.", deploymentId, numberOfAllocations);
507+
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
508+
.execute(
509+
() -> inferenceAuditor.warning(
510+
deploymentId,
511+
Strings.format(
512+
"adaptive allocations scaler: scaling [%s] to [%s] allocations failed.",
513+
deploymentId,
514+
numberOfAllocations
515+
)
516+
)
517+
);
518+
});
519+
}
490520
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/waitforallocations/ScalingInference.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.cluster.ClusterState;
14-
import org.elasticsearch.common.Strings;
1514
import org.elasticsearch.tasks.TaskId;
1615
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
1716
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
@@ -58,11 +57,10 @@ public ScalingInference(
5857
}
5958

6059
public synchronized void waitForAssignment(WaitingRequest request) {
61-
logger.info("new wait for request");
6260
var p = queueRequests.computeIfAbsent(request.deploymentId(), k -> new LinkedBlockingQueue<>());
6361

6462
if (p.isEmpty()) {
65-
logger.info("will wait for condition");
63+
logger.info("waitForAssignment will wait for condition");
6664
p.offer(request);
6765
assignmentService.waitForAssignmentCondition(
6866
request.deploymentId(),
@@ -71,7 +69,7 @@ public synchronized void waitForAssignment(WaitingRequest request) {
7169
new WaitingListener(request.deploymentId())
7270
);
7371
} else {
74-
logger.info("added to queue");
72+
logger.info("waitForAssignment added to queue");
7573
p.offer(request);
7674
}
7775

0 commit comments

Comments
 (0)