Skip to content

Commit 7a56691

Browse files
committed
protect against multiple requests
1 parent 3413592 commit 7a56691

File tree

7 files changed

+220
-51
lines changed

7 files changed

+220
-51
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ public boolean hasStartedRoutes() {
226226

227227
public List<Tuple<String, Integer>> selectRandomNodesWeighedOnAllocationsForNRequestsAndState(
228228
int numberOfRequests,
229-
RoutingState ... acceptableStates
229+
RoutingState... acceptableStates
230230
) {
231231
List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
232232
List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AdaptiveAllocationsScaleFromZeroIT.java

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import org.junit.Before;
1717

1818
import java.io.IOException;
19+
import java.util.Arrays;
1920
import java.util.List;
2021
import java.util.Map;
2122
import java.util.concurrent.CountDownLatch;
2223
import java.util.concurrent.TimeUnit;
2324

25+
import static org.hamcrest.Matchers.hasSize;
2426
import static org.hamcrest.Matchers.is;
2527
import static org.hamcrest.Matchers.not;
2628
import static org.hamcrest.Matchers.nullValue;
@@ -69,7 +71,7 @@ public void testScaleFromZero() throws Exception {
6971
}, 30, TimeUnit.SECONDS);
7072

7173
// infer will scale up
72-
int inferenceCount = 10;
74+
int inferenceCount = 100;
7375
var latch = new CountDownLatch(inferenceCount);
7476
for (int i = 0; i < inferenceCount; i++) {
7577
asyncInfer("Auto scale and infer", modelId, TimeValue.timeValueSeconds(5), new ResponseListener() {
@@ -89,8 +91,52 @@ public void onFailure(Exception exception) {
8991
latch.await();
9092
}
9193

92-
// public void testMultipleDeploymentsWaiting() {
93-
//
94-
// }
94+
@SuppressWarnings("unchecked")
95+
public void testMultipleDeploymentsWaiting() throws Exception {
96+
String id1 = "test_scale_from_zero_dep_1";
97+
String id2 = "test_scale_from_zero_dep_2";
98+
String id3 = "test_scale_from_zero_dep_3";
99+
var idsList = Arrays.asList(id1, id2, id3);
100+
for (var modelId : idsList) {
101+
createPassThroughModel(modelId);
102+
putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE);
103+
putVocabulary(List.of("Auto", "scale", "and", "infer"), modelId);
104+
105+
startDeployment(modelId, modelId, new AdaptiveAllocationsSettings(true, 0, 1));
106+
}
107+
108+
// wait for scale down. The scaler service will check every 10 seconds
109+
assertBusy(() -> {
110+
var statsMap = entityAsMap(getTrainedModelStats("test_scale_from_zero_dep_*"));
111+
List<Map<String, Object>> innerStats = (List<Map<String, Object>>) statsMap.get("trained_model_stats");
112+
assertThat(innerStats, hasSize(3));
113+
for (int i = 0; i < 3; i++) {
114+
Integer innerCount = (Integer) XContentMapValues.extractValue(
115+
"deployment_stats.allocation_status.allocation_count",
116+
innerStats.get(i)
117+
);
118+
assertThat(statsMap.toString(), innerCount, is(0));
119+
}
120+
}, 30, TimeUnit.SECONDS);
95121

122+
// infer will scale up
123+
int inferenceCount = 100;
124+
var latch = new CountDownLatch(inferenceCount);
125+
for (int i = 0; i < inferenceCount; i++) {
126+
asyncInfer("Auto scale and infer", randomFrom(idsList), TimeValue.timeValueSeconds(5), new ResponseListener() {
127+
@Override
128+
public void onSuccess(Response response) {
129+
latch.countDown();
130+
}
131+
132+
@Override
133+
public void onFailure(Exception exception) {
134+
latch.countDown();
135+
fail(exception.getMessage());
136+
}
137+
});
138+
}
139+
140+
latch.await();
141+
}
96142
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,10 @@ private void inferAgainstAllocatedModel(
268268

269269
// We couldn't find any nodes in the started state so let's look for ones that are stopping in case we're shutting down some nodes
270270
if (nodes.isEmpty()) {
271-
nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(request.numberOfDocuments(), RoutingState.STOPPING);
271+
nodes = assignment.selectRandomNodesWeighedOnAllocationsForNRequestsAndState(
272+
request.numberOfDocuments(),
273+
RoutingState.STOPPING
274+
);
272275
}
273276

274277
if (nodes.isEmpty()) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.cluster.service.ClusterService;
1312
import org.elasticsearch.common.Strings;
14-
import org.elasticsearch.core.TimeValue;
15-
16-
import static org.elasticsearch.xpack.ml.MachineLearning.ADAPTIVE_ALLOCATIONS_SCALE_TO_ZERO_TIME;
1713

1814
/**
1915
* Processes measured requests counts and inference times and decides whether

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.threadpool.Scheduler;
2626
import org.elasticsearch.threadpool.ThreadPool;
2727
import org.elasticsearch.xpack.core.ClientHelper;
28+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
2829
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
2930
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
3031
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
@@ -40,6 +41,7 @@
4041
import java.util.List;
4142
import java.util.Map;
4243
import java.util.Set;
44+
import java.util.concurrent.ConcurrentSkipListSet;
4345
import java.util.concurrent.atomic.AtomicBoolean;
4446
import java.util.concurrent.atomic.AtomicLong;
4547
import java.util.function.Function;
@@ -206,6 +208,8 @@ Collection<DoubleWithAttributes> observeDouble(Function<AdaptiveAllocationsScale
206208

207209
private final AtomicLong scaleToZeroAfterNoRequestsSeconds = new AtomicLong();
208210

211+
private final Set<String> inFlightScaleFromZeroRequests = new ConcurrentSkipListSet<>();
212+
209213
public AdaptiveAllocationsScalerService(
210214
ThreadPool threadPool,
211215
ClusterService clusterService,
@@ -287,8 +291,11 @@ private synchronized void updateAutoscalers(ClusterState state) {
287291
&& assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE) {
288292
AdaptiveAllocationsScaler adaptiveAllocationsScaler = scalers.computeIfAbsent(
289293
assignment.getDeploymentId(),
290-
key -> new AdaptiveAllocationsScaler(assignment.getDeploymentId(), assignment.totalTargetAllocations(),
291-
scaleToZeroAfterNoRequestsSeconds.get())
294+
key -> new AdaptiveAllocationsScaler(
295+
assignment.getDeploymentId(),
296+
assignment.totalTargetAllocations(),
297+
scaleToZeroAfterNoRequestsSeconds.get()
298+
)
292299
);
293300
adaptiveAllocationsScaler.setMinMaxNumberOfAllocations(
294301
assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations(),
@@ -416,22 +423,42 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
416423
if (newNumberOfAllocations > numberOfAllocations.get(deploymentId)) {
417424
lastScaleUpTimesMillis.put(deploymentId, now);
418425
}
419-
updateNumberOfAllocations(deploymentId, newNumberOfAllocations);
426+
updateNumberOfAllocations(
427+
deploymentId,
428+
newNumberOfAllocations,
429+
updateAssigmentListener(deploymentId, newNumberOfAllocations)
430+
);
420431
}
421432
}
422433
}
423434

424435
public boolean maybeStartAllocation(TrainedModelAssignment assignment) {
425436
if (assignment.getAdaptiveAllocationsSettings() != null
426-
&& assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE) {
427-
lastScaleUpTimesMillis.put(assignment.getDeploymentId(), System.currentTimeMillis());
428-
updateNumberOfAllocations(assignment.getDeploymentId(), 1);
437+
&& assignment.getAdaptiveAllocationsSettings().getEnabled() == Boolean.TRUE
438+
&& assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations() == 0) {
439+
440+
// Protect against a flurry of scale-up requests.
441+
if (inFlightScaleFromZeroRequests.contains(assignment.getDeploymentId()) == false) {
442+
lastScaleUpTimesMillis.put(assignment.getDeploymentId(), System.currentTimeMillis());
443+
var updateListener = updateAssigmentListener(assignment.getDeploymentId(), 1);
444+
var cleanUpListener = ActionListener.runAfter(
445+
updateListener,
446+
() -> inFlightScaleFromZeroRequests.remove(assignment.getDeploymentId())
447+
);
448+
449+
inFlightScaleFromZeroRequests.add(assignment.getDeploymentId());
450+
updateNumberOfAllocations(assignment.getDeploymentId(), 1, cleanUpListener);
451+
}
429452
return true;
430453
}
431454
return false;
432455
}
433456

434-
private void updateNumberOfAllocations(String deploymentId, int numberOfAllocations) {
457+
private void updateNumberOfAllocations(
458+
String deploymentId,
459+
int numberOfAllocations,
460+
ActionListener<CreateTrainedModelAssignmentAction.Response> listener
461+
) {
435462
UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
436463
updateRequest.setNumberOfAllocations(numberOfAllocations);
437464
updateRequest.setIsInternal(true);
@@ -440,40 +467,43 @@ private void updateNumberOfAllocations(String deploymentId, int numberOfAllocati
440467
ClientHelper.ML_ORIGIN,
441468
UpdateTrainedModelDeploymentAction.INSTANCE,
442469
updateRequest,
443-
ActionListener.wrap(updateResponse -> {
444-
logger.info("adaptive allocations scaler: scaled [{}] to [{}] allocations.", deploymentId, numberOfAllocations);
445-
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
446-
.execute(
447-
() -> inferenceAuditor.info(
448-
deploymentId,
449-
Strings.format(
450-
"adaptive allocations scaler: scaled [%s] to [%s] allocations.",
451-
deploymentId,
452-
numberOfAllocations
453-
)
454-
)
455-
);
456-
}, e -> {
457-
logger.atLevel(Level.WARN)
458-
.withThrowable(e)
459-
.log("adaptive allocations scaler: scaling [{}] to [{}] allocations failed.", deploymentId, numberOfAllocations);
460-
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
461-
.execute(
462-
() -> inferenceAuditor.warning(
463-
deploymentId,
464-
Strings.format(
465-
"adaptive allocations scaler: scaling [%s] to [%s] allocations failed.",
466-
deploymentId,
467-
numberOfAllocations
468-
)
469-
)
470-
);
471-
})
470+
listener
472471
);
473472
}
474473

475474
private void setScaleToZeroPeriod(TimeValue timeValue) {
476475
logger.info("setting scaler service to zero " + timeValue);
477476
scaleToZeroAfterNoRequestsSeconds.set(timeValue.seconds());
478477
}
478+
479+
private ActionListener<CreateTrainedModelAssignmentAction.Response> updateAssigmentListener(
480+
String deploymentId,
481+
int numberOfAllocations
482+
) {
483+
return ActionListener.wrap(updateResponse -> {
484+
logger.info("adaptive allocations scaler: scaled [{}] to [{}] allocations.", deploymentId, numberOfAllocations);
485+
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
486+
.execute(
487+
() -> inferenceAuditor.info(
488+
deploymentId,
489+
Strings.format("adaptive allocations scaler: scaled [%s] to [%s] allocations.", deploymentId, numberOfAllocations)
490+
)
491+
);
492+
}, e -> {
493+
logger.atLevel(Level.WARN)
494+
.withThrowable(e)
495+
.log("adaptive allocations scaler: scaling [{}] to [{}] allocations failed.", deploymentId, numberOfAllocations);
496+
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
497+
.execute(
498+
() -> inferenceAuditor.warning(
499+
deploymentId,
500+
Strings.format(
501+
"adaptive allocations scaler: scaling [%s] to [%s] allocations failed.",
502+
deploymentId,
503+
numberOfAllocations
504+
)
505+
)
506+
);
507+
});
508+
}
479509
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/waitforallocations/ScalingInference.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.cluster.ClusterState;
14-
import org.elasticsearch.common.Strings;
1514
import org.elasticsearch.tasks.TaskId;
1615
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
1716
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
@@ -58,11 +57,10 @@ public ScalingInference(
5857
}
5958

6059
public synchronized void waitForAssignment(WaitingRequest request) {
61-
logger.info("new wait for request");
6260
var p = queueRequests.computeIfAbsent(request.deploymentId(), k -> new LinkedBlockingQueue<>());
6361

6462
if (p.isEmpty()) {
65-
logger.info("will wait for condition");
63+
logger.info("waitForAssignment will wait for condition");
6664
p.offer(request);
6765
assignmentService.waitForAssignmentCondition(
6866
request.deploymentId(),
@@ -71,7 +69,7 @@ public synchronized void waitForAssignment(WaitingRequest request) {
7169
new WaitingListener(request.deploymentId())
7270
);
7371
} else {
74-
logger.info("added to queue");
72+
logger.info("waitForAssignment added to queue");
7573
p.offer(request);
7674
}
7775

0 commit comments

Comments
 (0)