Skip to content

Commit c6ac294

Browse files
authored
Don't immediately scale down startups triggered by non-master nodes in inference adaptive allocations. (#125297) (#125373)
1 parent b08a6c0 commit c6ac294

File tree

2 files changed

+210
-16
lines changed

2 files changed

+210
-16
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
3030
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
3131
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
32+
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
3233
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
3334
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
3435
import org.elasticsearch.xpack.ml.MachineLearning;
@@ -212,6 +213,7 @@ Collection<DoubleWithAttributes> observeDouble(Function<AdaptiveAllocationsScale
212213
private volatile Scheduler.Cancellable cancellable;
213214
private final AtomicBoolean busy;
214215
private final long scaleToZeroAfterNoRequestsSeconds;
216+
private final long scaleUpCooldownTimeMillis;
215217
private final Set<String> deploymentIdsWithInFlightScaleFromZeroRequests = new ConcurrentSkipListSet<>();
216218
private final Map<String, String> lastWarningMessages = new ConcurrentHashMap<>();
217219

@@ -223,7 +225,17 @@ public AdaptiveAllocationsScalerService(
223225
MeterRegistry meterRegistry,
224226
boolean isNlpEnabled
225227
) {
226-
this(threadPool, clusterService, client, inferenceAuditor, meterRegistry, isNlpEnabled, DEFAULT_TIME_INTERVAL_SECONDS);
228+
this(
229+
threadPool,
230+
clusterService,
231+
client,
232+
inferenceAuditor,
233+
meterRegistry,
234+
isNlpEnabled,
235+
DEFAULT_TIME_INTERVAL_SECONDS,
236+
SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS,
237+
SCALE_UP_COOLDOWN_TIME_MILLIS
238+
);
227239
}
228240

229241
// visible for testing
@@ -234,7 +246,9 @@ public AdaptiveAllocationsScalerService(
234246
InferenceAuditor inferenceAuditor,
235247
MeterRegistry meterRegistry,
236248
boolean isNlpEnabled,
237-
int timeIntervalSeconds
249+
int timeIntervalSeconds,
250+
long scaleToZeroAfterNoRequestsSeconds,
251+
long scaleUpCooldownTimeMillis
238252
) {
239253
this.threadPool = threadPool;
240254
this.clusterService = clusterService;
@@ -243,14 +257,15 @@ public AdaptiveAllocationsScalerService(
243257
this.meterRegistry = meterRegistry;
244258
this.isNlpEnabled = isNlpEnabled;
245259
this.timeIntervalSeconds = timeIntervalSeconds;
260+
this.scaleToZeroAfterNoRequestsSeconds = scaleToZeroAfterNoRequestsSeconds;
261+
this.scaleUpCooldownTimeMillis = scaleUpCooldownTimeMillis;
246262

247263
lastInferenceStatsByDeploymentAndNode = new HashMap<>();
248264
lastInferenceStatsTimestampMillis = null;
249265
lastScaleUpTimesMillis = new HashMap<>();
250266
scalers = new HashMap<>();
251267
metrics = new Metrics();
252268
busy = new AtomicBoolean(false);
253-
scaleToZeroAfterNoRequestsSeconds = SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS;
254269
}
255270

256271
public synchronized void start() {
@@ -374,6 +389,9 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
374389

375390
Map<String, Stats> recentStatsByDeployment = new HashMap<>();
376391
Map<String, Integer> numberOfAllocations = new HashMap<>();
392+
// Check for recent scale ups in the deployment stats, because a different node may have
393+
// caused a scale up when an inference request arrives and there were zero allocations.
394+
Set<String> hasRecentObservedScaleUp = new HashSet<>();
377395

378396
for (AssignmentStats assignmentStats : statsResponse.getStats().results()) {
379397
String deploymentId = assignmentStats.getDeploymentId();
@@ -399,6 +417,12 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
399417
(key, value) -> value == null ? recentStats : value.add(recentStats)
400418
);
401419
}
420+
if (nodeStats.getRoutingState() != null && nodeStats.getRoutingState().getState() == RoutingState.STARTING) {
421+
hasRecentObservedScaleUp.add(deploymentId);
422+
}
423+
if (nodeStats.getStartTime() != null && now < nodeStats.getStartTime().toEpochMilli() + scaleUpCooldownTimeMillis) {
424+
hasRecentObservedScaleUp.add(deploymentId);
425+
}
402426
}
403427
}
404428

@@ -414,9 +438,12 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
414438
Integer newNumberOfAllocations = adaptiveAllocationsScaler.scale();
415439
if (newNumberOfAllocations != null) {
416440
Long lastScaleUpTimeMillis = lastScaleUpTimesMillis.get(deploymentId);
441+
// hasRecentScaleUp indicates whether this service has recently scaled up the deployment.
442+
// hasRecentObservedScaleUp indicates whether a deployment recently has started,
443+
// potentially triggered by another node.
444+
boolean hasRecentScaleUp = lastScaleUpTimeMillis != null && now < lastScaleUpTimeMillis + scaleUpCooldownTimeMillis;
417445
if (newNumberOfAllocations < numberOfAllocations.get(deploymentId)
418-
&& lastScaleUpTimeMillis != null
419-
&& now < lastScaleUpTimeMillis + SCALE_UP_COOLDOWN_TIME_MILLIS) {
446+
&& (hasRecentScaleUp || hasRecentObservedScaleUp.contains(deploymentId))) {
420447
logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId);
421448
continue;
422449
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java

Lines changed: 178 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
import org.junit.After;
3737
import org.junit.Before;
3838

39-
import java.io.IOException;
4039
import java.time.Instant;
40+
import java.time.temporal.ChronoUnit;
4141
import java.util.List;
4242
import java.util.Map;
4343
import java.util.Set;
@@ -114,7 +114,12 @@ private ClusterState getClusterState(int numAllocations) {
114114
return clusterState;
115115
}
116116

117-
private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllocations, int inferenceCount, double latency) {
117+
private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
118+
int numAllocations,
119+
int inferenceCount,
120+
double latency,
121+
boolean recentStartup
122+
) {
118123
return new GetDeploymentStatsAction.Response(
119124
List.of(),
120125
List.of(),
@@ -127,7 +132,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
127132
new AdaptiveAllocationsSettings(true, null, null),
128133
1024,
129134
ByteSizeValue.ZERO,
130-
Instant.now(),
135+
Instant.now().minus(1, ChronoUnit.DAYS),
131136
List.of(
132137
AssignmentStats.NodeStats.forStartedState(
133138
DiscoveryNodeUtils.create("node_1"),
@@ -140,7 +145,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
140145
0,
141146
0,
142147
Instant.now(),
143-
Instant.now(),
148+
recentStartup ? Instant.now() : Instant.now().minus(1, ChronoUnit.HOURS),
144149
1,
145150
numAllocations,
146151
inferenceCount,
@@ -156,7 +161,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
156161
);
157162
}
158163

159-
public void test() throws IOException {
164+
public void test_scaleUp() {
160165
// Initialize the cluster with a deployment with 1 allocation.
161166
ClusterState clusterState = getClusterState(1);
162167
when(clusterService.state()).thenReturn(clusterState);
@@ -168,7 +173,9 @@ public void test() throws IOException {
168173
inferenceAuditor,
169174
meterRegistry,
170175
true,
171-
1
176+
1,
177+
60,
178+
60_000
172179
);
173180
service.start();
174181

@@ -182,7 +189,7 @@ public void test() throws IOException {
182189
doAnswer(invocationOnMock -> {
183190
@SuppressWarnings("unchecked")
184191
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
185-
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0));
192+
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
186193
return Void.TYPE;
187194
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
188195

@@ -198,7 +205,7 @@ public void test() throws IOException {
198205
doAnswer(invocationOnMock -> {
199206
@SuppressWarnings("unchecked")
200207
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
201-
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0));
208+
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false));
202209
return Void.TYPE;
203210
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
204211
doAnswer(invocationOnMock -> {
@@ -229,7 +236,137 @@ public void test() throws IOException {
229236
doAnswer(invocationOnMock -> {
230237
@SuppressWarnings("unchecked")
231238
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
232-
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0));
239+
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false));
240+
return Void.TYPE;
241+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
242+
doAnswer(invocationOnMock -> {
243+
@SuppressWarnings("unchecked")
244+
var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
245+
listener.onResponse(null);
246+
return Void.TYPE;
247+
}).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());
248+
249+
safeSleep(1000);
250+
251+
verify(client, times(1)).threadPool();
252+
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
253+
verifyNoMoreInteractions(client, clusterService);
254+
255+
service.stop();
256+
}
257+
258+
public void test_scaleDownToZero_whenNoRequests() {
259+
// Initialize the cluster with a deployment with 1 allocation.
260+
ClusterState clusterState = getClusterState(1);
261+
when(clusterService.state()).thenReturn(clusterState);
262+
263+
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
264+
threadPool,
265+
clusterService,
266+
client,
267+
inferenceAuditor,
268+
meterRegistry,
269+
true,
270+
1,
271+
1,
272+
2_000
273+
);
274+
service.start();
275+
276+
verify(clusterService).state();
277+
verify(clusterService).addListener(same(service));
278+
verifyNoMoreInteractions(client, clusterService);
279+
reset(client, clusterService);
280+
281+
// First cycle: 1 inference request, so no need for scaling.
282+
when(client.threadPool()).thenReturn(threadPool);
283+
doAnswer(invocationOnMock -> {
284+
@SuppressWarnings("unchecked")
285+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
286+
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
287+
return Void.TYPE;
288+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
289+
290+
safeSleep(1200);
291+
292+
verify(client, times(1)).threadPool();
293+
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
294+
verifyNoMoreInteractions(client, clusterService);
295+
reset(client, clusterService);
296+
297+
// Second cycle: 0 inference requests for 1 second, so scale down to 0 allocations.
298+
when(client.threadPool()).thenReturn(threadPool);
299+
doAnswer(invocationOnMock -> {
300+
@SuppressWarnings("unchecked")
301+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
302+
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
303+
return Void.TYPE;
304+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
305+
doAnswer(invocationOnMock -> {
306+
@SuppressWarnings("unchecked")
307+
var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
308+
listener.onResponse(null);
309+
return Void.TYPE;
310+
}).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());
311+
312+
safeSleep(1000);
313+
314+
verify(client, times(2)).threadPool();
315+
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
316+
var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment");
317+
updateRequest.setNumberOfAllocations(0);
318+
updateRequest.setIsInternal(true);
319+
verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any());
320+
verifyNoMoreInteractions(client, clusterService);
321+
322+
service.stop();
323+
}
324+
325+
public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
326+
// Initialize the cluster with a deployment with 1 allocation.
327+
ClusterState clusterState = getClusterState(1);
328+
when(clusterService.state()).thenReturn(clusterState);
329+
330+
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
331+
threadPool,
332+
clusterService,
333+
client,
334+
inferenceAuditor,
335+
meterRegistry,
336+
true,
337+
1,
338+
1,
339+
2_000
340+
);
341+
service.start();
342+
343+
verify(clusterService).state();
344+
verify(clusterService).addListener(same(service));
345+
verifyNoMoreInteractions(client, clusterService);
346+
reset(client, clusterService);
347+
348+
// First cycle: 1 inference request, so no need for scaling.
349+
when(client.threadPool()).thenReturn(threadPool);
350+
doAnswer(invocationOnMock -> {
351+
@SuppressWarnings("unchecked")
352+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
353+
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true));
354+
return Void.TYPE;
355+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
356+
357+
safeSleep(1200);
358+
359+
verify(client, times(1)).threadPool();
360+
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
361+
verifyNoMoreInteractions(client, clusterService);
362+
reset(client, clusterService);
363+
364+
// Second cycle: 0 inference requests for 1 second, but a recent scale up by another node.
365+
when(client.threadPool()).thenReturn(threadPool);
366+
doAnswer(invocationOnMock -> {
367+
@SuppressWarnings("unchecked")
368+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
369+
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true));
233370
return Void.TYPE;
234371
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
235372
doAnswer(invocationOnMock -> {
@@ -244,6 +381,32 @@ public void test() throws IOException {
244381
verify(client, times(1)).threadPool();
245382
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
246383
verifyNoMoreInteractions(client, clusterService);
384+
reset(client, clusterService);
385+
386+
// Third cycle: 0 inference requests for 1 second and no recent scale up, so scale down to 0 allocations.
387+
when(client.threadPool()).thenReturn(threadPool);
388+
doAnswer(invocationOnMock -> {
389+
@SuppressWarnings("unchecked")
390+
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
391+
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
392+
return Void.TYPE;
393+
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
394+
doAnswer(invocationOnMock -> {
395+
@SuppressWarnings("unchecked")
396+
var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
397+
listener.onResponse(null);
398+
return Void.TYPE;
399+
}).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());
400+
401+
safeSleep(1000);
402+
403+
verify(client, times(2)).threadPool();
404+
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
405+
var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment");
406+
updateRequest.setNumberOfAllocations(0);
407+
updateRequest.setIsInternal(true);
408+
verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any());
409+
verifyNoMoreInteractions(client, clusterService);
247410

248411
service.stop();
249412
}
@@ -256,7 +419,9 @@ public void testMaybeStartAllocation() {
256419
inferenceAuditor,
257420
meterRegistry,
258421
true,
259-
1
422+
1,
423+
60,
424+
60_000
260425
);
261426

262427
when(client.threadPool()).thenReturn(threadPool);
@@ -289,7 +454,9 @@ public void testMaybeStartAllocation_BlocksMultipleRequests() throws Exception {
289454
inferenceAuditor,
290455
meterRegistry,
291456
true,
292-
1
457+
1,
458+
60,
459+
60_000
293460
);
294461

295462
var latch = new CountDownLatch(1);

0 commit comments

Comments
 (0)