diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index 249e27d6f25e0..e8621a68e5c22 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -144,7 +144,7 @@ private TrainedModelAssignment( * @param assignmentState used to track the state of the assignment for rebalancing, autoscaling, and more * @param reason may contain a human-readable explanation for the current state * @param startTime the time when the assignment was created - * @param maxAssignedAllocations used for adaptive allocations + * @param maxAssignedAllocations keeps track of the maximum number of allocations used for this assignment * @param adaptiveAllocationsSettings how the assignment should scale based on usage */ TrainedModelAssignment( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java index e8e4061395394..0bb7256ee6143 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; @@ -390,6 +391,7 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo Map recentStatsByDeployment = new HashMap<>(); Map numberOfAllocations = new HashMap<>(); + Map assignmentStates = new HashMap<>(); // Check for recent scale ups in the deployment stats, because a different node may have // caused a scale up when an inference request arrives and there were zero allocations. Set hasRecentObservedScaleUp = new HashSet<>(); @@ -397,6 +399,7 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo for (AssignmentStats assignmentStats : statsResponse.getStats().results()) { String deploymentId = assignmentStats.getDeploymentId(); numberOfAllocations.put(deploymentId, assignmentStats.getNumberOfAllocations()); + assignmentStates.put(deploymentId, assignmentStats.getState()); Map deploymentStats = lastInferenceStatsByDeploymentAndNode.computeIfAbsent( deploymentId, key -> new HashMap<>() @@ -449,6 +452,14 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId); continue; } + if (assignmentStates.get(deploymentId) != AssignmentState.STARTED) { + logger.debug( + "adaptive allocations scaler: skipping scaling [{}] because it is in [{}] state.", + deploymentId, + assignmentStates.get(deploymentId) + ); + continue; + } if (newNumberOfAllocations > numberOfAllocations.get(deploymentId)) { lastScaleUpTimesMillis.put(deploymentId, now); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index 8b5f33e25e242..8b33a9fd54fad 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -50,7 +50,10 @@ public class AssignmentPlanner { public AssignmentPlanner(List nodes, List deployments) { this.nodes = nodes.stream().sorted(Comparator.comparing(Node::id)).toList(); - this.deployments = deployments.stream().sorted(Comparator.comparing(AssignmentPlan.Deployment::deploymentId)).toList(); + this.deployments = deployments.stream() + .filter(deployment -> deployment.allocations() > 0) + .sorted(Comparator.comparing(AssignmentPlan.Deployment::deploymentId)) + .toList(); } public AssignmentPlan computePlan() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java index 469df5f45158c..952a228f004a6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; @@ -85,7 +86,7 @@ public void tearDown() throws Exception { super.tearDown(); } - private ClusterState getClusterState(int numAllocations) { + private ClusterState getClusterState(int numAllocations, AssignmentState assignmentState) { ClusterState clusterState = mock(ClusterState.class); Metadata metadata = mock(Metadata.class); when(clusterState.metadata()).thenReturn(metadata); @@ -107,7 +108,7 @@ private ClusterState getClusterState(int numAllocations) { 100_000_000 ), new AdaptiveAllocationsSettings(true, null, null) - ).build() + ).setAssignmentState(assignmentState).build() ) ) ); @@ -118,7 +119,8 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse( int numAllocations, int inferenceCount, double latency, - boolean recentStartup + boolean recentStartup, + AssignmentState assignmentState ) { return new GetDeploymentStatsAction.Response( List.of(), @@ -155,7 +157,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse( ) ), Priority.NORMAL - ) + ).setState(assignmentState) ), 0 ); @@ -163,7 +165,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse( public void test_scaleUp() { // Initialize the cluster with a deployment with 1 allocation. - ClusterState clusterState = getClusterState(1); + ClusterState clusterState = getClusterState(1, AssignmentState.STARTED); when(clusterService.state()).thenReturn(clusterState); AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService( @@ -189,7 +191,7 @@ public void test_scaleUp() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false)); + listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); @@ -205,7 +207,7 @@ public void test_scaleUp() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false)); + listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); doAnswer(invocationOnMock -> { @@ -226,7 +228,7 @@ public void test_scaleUp() { verifyNoMoreInteractions(client, clusterService); reset(client, clusterService); - clusterState = getClusterState(2); + clusterState = getClusterState(2, AssignmentState.STARTED); ClusterChangedEvent clusterChangedEvent = mock(ClusterChangedEvent.class); when(clusterChangedEvent.state()).thenReturn(clusterState); service.clusterChanged(clusterChangedEvent); @@ -236,7 +238,7 @@ public void test_scaleUp() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false)); + listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); doAnswer(invocationOnMock -> { @@ -257,7 +259,7 @@ public void test_scaleUp() { public void test_scaleDownToZero_whenNoRequests() { // Initialize the cluster with a deployment with 1 allocation. - ClusterState clusterState = getClusterState(1); + ClusterState clusterState = getClusterState(1, AssignmentState.STARTED); when(clusterService.state()).thenReturn(clusterState); AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService( @@ -283,7 +285,7 @@ public void test_scaleDownToZero_whenNoRequests() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false)); + listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); @@ -299,7 +301,7 @@ public void test_scaleDownToZero_whenNoRequests() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false)); + listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); doAnswer(invocationOnMock -> { @@ -322,9 +324,65 @@ public void test_scaleDownToZero_whenNoRequests() { service.stop(); } + public void test_dontScale_whenNotStarted() { + // Initialize the cluster with a deployment with 1 allocation. + ClusterState clusterState = getClusterState(1, AssignmentState.STARTING); + when(clusterService.state()).thenReturn(clusterState); + + AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService( + threadPool, + clusterService, + client, + inferenceAuditor, + meterRegistry, + true, + 1, + 1, + 2_000 + ); + service.start(); + + verify(clusterService).state(); + verify(clusterService).addListener(same(service)); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // First cycle: many inference requests + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 10000, 10.0, false, AssignmentState.STARTING)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + + safeSleep(1200); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // Second cycle: again many inference requests + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 20000, 10.0, false, AssignmentState.STARTING)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + + safeSleep(1200); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + service.stop(); + } + public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() { // Initialize the cluster with a deployment with 1 allocation. - ClusterState clusterState = getClusterState(1); + ClusterState clusterState = getClusterState(1, AssignmentState.STARTED); when(clusterService.state()).thenReturn(clusterState); AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService( @@ -350,7 +408,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true)); + listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); @@ -366,7 +424,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true)); + listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); doAnswer(invocationOnMock -> { @@ -388,7 +446,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() { doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false)); + listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false, AssignmentState.STARTED)); return Void.TYPE; }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); doAnswer(invocationOnMock -> {