diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java index eef560527a0c0..e8e4061395394 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; import org.elasticsearch.xpack.ml.MachineLearning; @@ -213,6 +214,7 @@ Collection observeDouble(Function deploymentIdsWithInFlightScaleFromZeroRequests = new ConcurrentSkipListSet<>(); private final Map lastWarningMessages = new ConcurrentHashMap<>(); @@ -224,7 +226,17 @@ public AdaptiveAllocationsScalerService( MeterRegistry meterRegistry, boolean isNlpEnabled ) { - this(threadPool, clusterService, client, inferenceAuditor, meterRegistry, isNlpEnabled, DEFAULT_TIME_INTERVAL_SECONDS); + this( + threadPool, + clusterService, + client, + inferenceAuditor, + meterRegistry, + isNlpEnabled, + DEFAULT_TIME_INTERVAL_SECONDS, + SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS, + SCALE_UP_COOLDOWN_TIME_MILLIS + ); } // visible for testing @@ -235,7 +247,9 @@ public AdaptiveAllocationsScalerService( InferenceAuditor inferenceAuditor, MeterRegistry meterRegistry, boolean isNlpEnabled, - int timeIntervalSeconds + int timeIntervalSeconds, + long scaleToZeroAfterNoRequestsSeconds, + long scaleUpCooldownTimeMillis ) { this.threadPool = threadPool; this.clusterService = clusterService; @@ -244,6 +258,8 @@ public AdaptiveAllocationsScalerService( this.meterRegistry = meterRegistry; this.isNlpEnabled = isNlpEnabled; this.timeIntervalSeconds = timeIntervalSeconds; + this.scaleToZeroAfterNoRequestsSeconds = scaleToZeroAfterNoRequestsSeconds; + this.scaleUpCooldownTimeMillis = scaleUpCooldownTimeMillis; lastInferenceStatsByDeploymentAndNode = new HashMap<>(); lastInferenceStatsTimestampMillis = null; @@ -251,7 +267,6 @@ public AdaptiveAllocationsScalerService( scalers = new HashMap<>(); metrics = new Metrics(); busy = new AtomicBoolean(false); - scaleToZeroAfterNoRequestsSeconds = SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS; } public synchronized void start() { @@ -375,6 +390,9 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo Map recentStatsByDeployment = new HashMap<>(); Map numberOfAllocations = new HashMap<>(); + // Check for recent scale ups in the deployment stats, because a different node may have + // caused a scale up when an inference request arrives and there were zero allocations. + Set hasRecentObservedScaleUp = new HashSet<>(); for (AssignmentStats assignmentStats : statsResponse.getStats().results()) { String deploymentId = assignmentStats.getDeploymentId(); @@ -401,6 +419,12 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo (key, value) -> value == null ? recentStats : value.add(recentStats) ); } + if (nodeStats.getRoutingState() != null && nodeStats.getRoutingState().getState() == RoutingState.STARTING) { + hasRecentObservedScaleUp.add(deploymentId); + } + if (nodeStats.getStartTime() != null && now < nodeStats.getStartTime().toEpochMilli() + scaleUpCooldownTimeMillis) { + hasRecentObservedScaleUp.add(deploymentId); + } } } @@ -416,9 +440,12 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo Integer newNumberOfAllocations = adaptiveAllocationsScaler.scale(); if (newNumberOfAllocations != null) { Long lastScaleUpTimeMillis = lastScaleUpTimesMillis.get(deploymentId); + // hasRecentScaleUp indicates whether this service has recently scaled up the deployment. + // hasRecentObservedScaleUp indicates whether a deployment recently has started, + // potentially triggered by another node. + boolean hasRecentScaleUp = lastScaleUpTimeMillis != null && now < lastScaleUpTimeMillis + scaleUpCooldownTimeMillis; if (newNumberOfAllocations < numberOfAllocations.get(deploymentId) - && lastScaleUpTimeMillis != null - && now < lastScaleUpTimeMillis + SCALE_UP_COOLDOWN_TIME_MILLIS) { + && (hasRecentScaleUp || hasRecentObservedScaleUp.contains(deploymentId))) { logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId); continue; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java index 62bf8e1c954e9..469df5f45158c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java @@ -36,8 +36,8 @@ import org.junit.After; import org.junit.Before; -import java.io.IOException; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.List; import java.util.Map; import java.util.Set; @@ -114,7 +114,12 @@ private ClusterState getClusterState(int numAllocations) { return clusterState; } - private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllocations, int inferenceCount, double latency) { + private GetDeploymentStatsAction.Response getDeploymentStatsResponse( + int numAllocations, + int inferenceCount, + double latency, + boolean recentStartup + ) { return new GetDeploymentStatsAction.Response( List.of(), List.of(), @@ -127,7 +132,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo new AdaptiveAllocationsSettings(true, null, null), 1024, ByteSizeValue.ZERO, - Instant.now(), + Instant.now().minus(1, ChronoUnit.DAYS), List.of( AssignmentStats.NodeStats.forStartedState( randomBoolean() ? DiscoveryNodeUtils.create("node_1") : null, @@ -140,7 +145,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo 0, 0, Instant.now(), - Instant.now(), + recentStartup ? Instant.now() : Instant.now().minus(1, ChronoUnit.HOURS), 1, numAllocations, inferenceCount, @@ -156,7 +161,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo ); } - public void test() throws IOException { + public void test_scaleUp() { // Initialize the cluster with a deployment with 1 allocation. ClusterState clusterState = getClusterState(1); when(clusterService.state()).thenReturn(clusterState); @@ -168,7 +173,9 @@ public void test() throws IOException { inferenceAuditor, meterRegistry, true, - 1 + 1, + 60, + 60_000 ); service.start(); @@ -182,7 +189,7 @@ public void test() throws IOException { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0)); + listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); @@ -198,7 +205,7 @@ public void test() throws IOException { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0)); + listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); doAnswer(invocationOnMock -> { @@ -229,7 +236,137 @@ public void test() throws IOException { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0)); + listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return Void.TYPE; + }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any()); + + safeSleep(1000); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + + service.stop(); + } + + public void test_scaleDownToZero_whenNoRequests() { + // Initialize the cluster with a deployment with 1 allocation. + ClusterState clusterState = getClusterState(1); + when(clusterService.state()).thenReturn(clusterState); + + AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService( + threadPool, + clusterService, + client, + inferenceAuditor, + meterRegistry, + true, + 1, + 1, + 2_000 + ); + service.start(); + + verify(clusterService).state(); + verify(clusterService).addListener(same(service)); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // First cycle: 1 inference request, so no need for scaling. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + + safeSleep(1200); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // Second cycle: 0 inference requests for 1 second, so scale down to 0 allocations. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return Void.TYPE; + }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any()); + + safeSleep(1000); + + verify(client, times(2)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment"); + updateRequest.setNumberOfAllocations(0); + updateRequest.setIsInternal(true); + verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any()); + verifyNoMoreInteractions(client, clusterService); + + service.stop(); + } + + public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() { + // Initialize the cluster with a deployment with 1 allocation. + ClusterState clusterState = getClusterState(1); + when(clusterService.state()).thenReturn(clusterState); + + AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService( + threadPool, + clusterService, + client, + inferenceAuditor, + meterRegistry, + true, + 1, + 1, + 2_000 + ); + service.start(); + + verify(clusterService).state(); + verify(clusterService).addListener(same(service)); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // First cycle: 1 inference request, so no need for scaling. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + + safeSleep(1200); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // Second cycle: 0 inference requests for 1 second, but a recent scale up by another node. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); doAnswer(invocationOnMock -> { @@ -244,6 +381,32 @@ public void test() throws IOException { verify(client, times(1)).threadPool(); verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // Third cycle: 0 inference requests for 1 second and no recent scale up, so scale down to 0 allocations. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return Void.TYPE; + }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any()); + + safeSleep(1000); + + verify(client, times(2)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment"); + updateRequest.setNumberOfAllocations(0); + updateRequest.setIsInternal(true); + verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any()); + verifyNoMoreInteractions(client, clusterService); service.stop(); } @@ -256,7 +419,9 @@ public void testMaybeStartAllocation() { inferenceAuditor, meterRegistry, true, - 1 + 1, + 60, + 60_000 ); when(client.threadPool()).thenReturn(threadPool); @@ -289,7 +454,9 @@ public void testMaybeStartAllocation_BlocksMultipleRequests() throws Exception { inferenceAuditor, meterRegistry, true, - 1 + 1, + 60, + 60_000 ); var latch = new CountDownLatch(1);