Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.MachineLearning;
Expand Down Expand Up @@ -213,6 +214,7 @@ Collection<DoubleWithAttributes> observeDouble(Function<AdaptiveAllocationsScale
private volatile Scheduler.Cancellable cancellable;
private final AtomicBoolean busy;
private final long scaleToZeroAfterNoRequestsSeconds;
private final long scaleUpCooldownTimeMillis;
private final Set<String> deploymentIdsWithInFlightScaleFromZeroRequests = new ConcurrentSkipListSet<>();
private final Map<String, String> lastWarningMessages = new ConcurrentHashMap<>();

Expand All @@ -224,7 +226,17 @@ public AdaptiveAllocationsScalerService(
MeterRegistry meterRegistry,
boolean isNlpEnabled
) {
this(threadPool, clusterService, client, inferenceAuditor, meterRegistry, isNlpEnabled, DEFAULT_TIME_INTERVAL_SECONDS);
this(
threadPool,
clusterService,
client,
inferenceAuditor,
meterRegistry,
isNlpEnabled,
DEFAULT_TIME_INTERVAL_SECONDS,
SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS,
SCALE_UP_COOLDOWN_TIME_MILLIS
);
}

// visible for testing
Expand All @@ -235,7 +247,9 @@ public AdaptiveAllocationsScalerService(
InferenceAuditor inferenceAuditor,
MeterRegistry meterRegistry,
boolean isNlpEnabled,
int timeIntervalSeconds
int timeIntervalSeconds,
long scaleToZeroAfterNoRequestsSeconds,
long scaleUpCooldownTimeMillis
) {
this.threadPool = threadPool;
this.clusterService = clusterService;
Expand All @@ -244,14 +258,15 @@ public AdaptiveAllocationsScalerService(
this.meterRegistry = meterRegistry;
this.isNlpEnabled = isNlpEnabled;
this.timeIntervalSeconds = timeIntervalSeconds;
this.scaleToZeroAfterNoRequestsSeconds = scaleToZeroAfterNoRequestsSeconds;
this.scaleUpCooldownTimeMillis = scaleUpCooldownTimeMillis;

lastInferenceStatsByDeploymentAndNode = new HashMap<>();
lastInferenceStatsTimestampMillis = null;
lastScaleUpTimesMillis = new HashMap<>();
scalers = new HashMap<>();
metrics = new Metrics();
busy = new AtomicBoolean(false);
scaleToZeroAfterNoRequestsSeconds = SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS;
}

public synchronized void start() {
Expand Down Expand Up @@ -375,6 +390,9 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo

Map<String, Stats> recentStatsByDeployment = new HashMap<>();
Map<String, Integer> numberOfAllocations = new HashMap<>();
// Check for recent scale ups in the deployment stats, because a different node may have
// caused a scale up when an inference request arrived while there were zero allocations.
Set<String> hasRecentObservedScaleUp = new HashSet<>();

for (AssignmentStats assignmentStats : statsResponse.getStats().results()) {
String deploymentId = assignmentStats.getDeploymentId();
Expand All @@ -401,6 +419,12 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
(key, value) -> value == null ? recentStats : value.add(recentStats)
);
}
if (nodeStats.getRoutingState() != null && nodeStats.getRoutingState().getState() == RoutingState.STARTING) {
hasRecentObservedScaleUp.add(deploymentId);
}
if (nodeStats.getStartTime() != null && now < nodeStats.getStartTime().toEpochMilli() + scaleUpCooldownTimeMillis) {
hasRecentObservedScaleUp.add(deploymentId);
}
}
}

Expand All @@ -416,9 +440,12 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
Integer newNumberOfAllocations = adaptiveAllocationsScaler.scale();
if (newNumberOfAllocations != null) {
Long lastScaleUpTimeMillis = lastScaleUpTimesMillis.get(deploymentId);
// hasRecentScaleUp indicates whether this service has recently scaled up the deployment.
// hasRecentObservedScaleUp indicates whether a deployment has recently started,
// potentially triggered by another node.
boolean hasRecentScaleUp = lastScaleUpTimeMillis != null && now < lastScaleUpTimeMillis + scaleUpCooldownTimeMillis;
if (newNumberOfAllocations < numberOfAllocations.get(deploymentId)
&& lastScaleUpTimeMillis != null
&& now < lastScaleUpTimeMillis + SCALE_UP_COOLDOWN_TIME_MILLIS) {
&& (hasRecentScaleUp || hasRecentObservedScaleUp.contains(deploymentId))) {
logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId);
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -114,7 +114,12 @@ private ClusterState getClusterState(int numAllocations) {
return clusterState;
}

private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllocations, int inferenceCount, double latency) {
private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
int numAllocations,
int inferenceCount,
double latency,
boolean recentStartup
) {
return new GetDeploymentStatsAction.Response(
List.of(),
List.of(),
Expand All @@ -127,7 +132,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
new AdaptiveAllocationsSettings(true, null, null),
1024,
ByteSizeValue.ZERO,
Instant.now(),
Instant.now().minus(1, ChronoUnit.DAYS),
List.of(
AssignmentStats.NodeStats.forStartedState(
randomBoolean() ? DiscoveryNodeUtils.create("node_1") : null,
Expand All @@ -140,7 +145,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
0,
0,
Instant.now(),
Instant.now(),
recentStartup ? Instant.now() : Instant.now().minus(1, ChronoUnit.HOURS),
1,
numAllocations,
inferenceCount,
Expand All @@ -156,7 +161,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
);
}

public void test() throws IOException {
public void test_scaleUp() {
// Initialize the cluster with a deployment with 1 allocation.
ClusterState clusterState = getClusterState(1);
when(clusterService.state()).thenReturn(clusterState);
Expand All @@ -168,7 +173,9 @@ public void test() throws IOException {
inferenceAuditor,
meterRegistry,
true,
1
1,
60,
60_000
);
service.start();

Expand All @@ -182,7 +189,7 @@ public void test() throws IOException {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0));
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

Expand All @@ -198,7 +205,7 @@ public void test() throws IOException {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0));
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
Expand Down Expand Up @@ -229,7 +236,137 @@ public void test() throws IOException {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0));
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(null);
return Void.TYPE;
}).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());

safeSleep(1000);

verify(client, times(1)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
verifyNoMoreInteractions(client, clusterService);

service.stop();
}

public void test_scaleDownToZero_whenNoRequests() {
// Initialize the cluster with a deployment with 1 allocation.
ClusterState clusterState = getClusterState(1);
when(clusterService.state()).thenReturn(clusterState);

AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
threadPool,
clusterService,
client,
inferenceAuditor,
meterRegistry,
true,
1,
1,
2_000
);
service.start();

verify(clusterService).state();
verify(clusterService).addListener(same(service));
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

// First cycle: 1 inference request, so no need for scaling.
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

safeSleep(1200);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this waiting for the doAnswer() call? I'm not sure whether this will be flaky, but another option would be to add a CountDownLatch and await it here until it has been counted down.

Copy link
Contributor Author

@jan-elastic jan-elastic Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is waiting for the adaptive allocations background process that makes these calls. That runs every 1s in this test (see constructor above). If within 1200ms nothing has happened, something is wrong.

With 200ms leeway it shouldn't be flaky, and never has been in the past.

I've also run these tests 100x locally without any problems.


verify(client, times(1)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

// Second cycle: 0 inference requests for 1 second, so scale down to 0 allocations.
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(null);
return Void.TYPE;
}).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());

safeSleep(1000);

verify(client, times(2)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment");
updateRequest.setNumberOfAllocations(0);
updateRequest.setIsInternal(true);
verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any());
verifyNoMoreInteractions(client, clusterService);

service.stop();
}

public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
// Initialize the cluster with a deployment with 1 allocation.
ClusterState clusterState = getClusterState(1);
when(clusterService.state()).thenReturn(clusterState);

AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
threadPool,
clusterService,
client,
inferenceAuditor,
meterRegistry,
true,
1,
1,
2_000
);
service.start();

verify(clusterService).state();
verify(clusterService).addListener(same(service));
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

// First cycle: 1 inference request, so no need for scaling.
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

safeSleep(1200);

verify(client, times(1)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

// Second cycle: 0 inference requests for 1 second, but a recent scale up by another node.
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
Expand All @@ -244,6 +381,32 @@ public void test() throws IOException {
verify(client, times(1)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

// Third cycle: 0 inference requests for 1 second and no recent scale up, so scale down to 0 allocations.
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(null);
return Void.TYPE;
}).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());

safeSleep(1000);

verify(client, times(2)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment");
updateRequest.setNumberOfAllocations(0);
updateRequest.setIsInternal(true);
verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any());
verifyNoMoreInteractions(client, clusterService);

service.stop();
}
Expand All @@ -256,7 +419,9 @@ public void testMaybeStartAllocation() {
inferenceAuditor,
meterRegistry,
true,
1
1,
60,
60_000
);

when(client.threadPool()).thenReturn(threadPool);
Expand Down Expand Up @@ -289,7 +454,9 @@ public void testMaybeStartAllocation_BlocksMultipleRequests() throws Exception {
inferenceAuditor,
meterRegistry,
true,
1
1,
60,
60_000
);

var latch = new CountDownLatch(1);
Expand Down