Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ private TrainedModelAssignment(
* @param assignmentState used to track the state of the assignment for rebalancing, autoscaling, and more
* @param reason may contain a human-readable explanation for the current state
* @param startTime the time when the assignment was created
* @param maxAssignedAllocations used for adaptive allocations
* @param maxAssignedAllocations keeps track of the maximum number of allocations used for this assignment
* @param adaptiveAllocationsSettings how the assignment should scale based on usage
*/
TrainedModelAssignment(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
Expand Down Expand Up @@ -389,13 +390,15 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo

Map<String, Stats> recentStatsByDeployment = new HashMap<>();
Map<String, Integer> numberOfAllocations = new HashMap<>();
Map<String, AssignmentState> assignmentStates = new HashMap<>();
// Check for recent scale ups in the deployment stats, because a different node may have
// caused a scale up when an inference request arrives and there were zero allocations.
Set<String> hasRecentObservedScaleUp = new HashSet<>();

for (AssignmentStats assignmentStats : statsResponse.getStats().results()) {
String deploymentId = assignmentStats.getDeploymentId();
numberOfAllocations.put(deploymentId, assignmentStats.getNumberOfAllocations());
assignmentStates.put(deploymentId, assignmentStats.getState());
Map<String, Stats> deploymentStats = lastInferenceStatsByDeploymentAndNode.computeIfAbsent(
deploymentId,
key -> new HashMap<>()
Expand Down Expand Up @@ -447,6 +450,14 @@ private void processDeploymentStats(GetDeploymentStatsAction.Response statsRespo
logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId);
continue;
}
if (assignmentStates.get(deploymentId) != AssignmentState.STARTED) {
logger.debug(
"adaptive allocations scaler: skipping scaling [{}] because it is in [{}] state.",
deploymentId,
assignmentStates.get(deploymentId)
);
continue;
}
if (newNumberOfAllocations > numberOfAllocations.get(deploymentId)) {
lastScaleUpTimesMillis.put(deploymentId, now);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ public class AssignmentPlanner {

/**
 * Creates a planner for the given nodes and deployments.
 * <p>
 * Nodes are stored sorted by {@code Node::id}. Deployments are stored sorted by
 * {@code deploymentId}, and only those whose {@code allocations()} is greater than
 * zero are kept.
 *
 * @param nodes       the nodes available to the planner
 * @param deployments the candidate deployments; entries with zero (or fewer) allocations are dropped
 */
public AssignmentPlanner(List<Node> nodes, List<AssignmentPlan.Deployment> deployments) {
    this.nodes = nodes.stream().sorted(Comparator.comparing(Node::id)).toList();
    // Assign the deployments exactly once: filter out deployments without any
    // requested allocations first, then sort the remainder by deployment id.
    this.deployments = deployments.stream()
        .filter(deployment -> deployment.allocations() > 0)
        .sorted(Comparator.comparing(AssignmentPlan.Deployment::deploymentId))
        .toList();
}

public AssignmentPlan computePlan() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
Expand Down Expand Up @@ -85,7 +86,7 @@ public void tearDown() throws Exception {
super.tearDown();
}

private ClusterState getClusterState(int numAllocations) {
private ClusterState getClusterState(int numAllocations, AssignmentState assignmentState) {
ClusterState clusterState = mock(ClusterState.class);
Metadata metadata = mock(Metadata.class);
when(clusterState.getMetadata()).thenReturn(metadata);
Expand All @@ -107,7 +108,7 @@ private ClusterState getClusterState(int numAllocations) {
100_000_000
),
new AdaptiveAllocationsSettings(true, null, null)
).build()
).setAssignmentState(assignmentState).build()
)
)
);
Expand All @@ -118,7 +119,8 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
int numAllocations,
int inferenceCount,
double latency,
boolean recentStartup
boolean recentStartup,
AssignmentState assignmentState
) {
return new GetDeploymentStatsAction.Response(
List.of(),
Expand Down Expand Up @@ -155,15 +157,15 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
)
),
Priority.NORMAL
)
).setState(assignmentState)
),
0
);
}

public void test_scaleUp() {
// Initialize the cluster with a deployment with 1 allocation.
ClusterState clusterState = getClusterState(1);
ClusterState clusterState = getClusterState(1, AssignmentState.STARTED);
when(clusterService.state()).thenReturn(clusterState);

AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
Expand All @@ -189,7 +191,7 @@ public void test_scaleUp() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

Expand All @@ -205,7 +207,7 @@ public void test_scaleUp() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false));
listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
Expand All @@ -226,7 +228,7 @@ public void test_scaleUp() {
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

clusterState = getClusterState(2);
clusterState = getClusterState(2, AssignmentState.STARTED);
ClusterChangedEvent clusterChangedEvent = mock(ClusterChangedEvent.class);
when(clusterChangedEvent.state()).thenReturn(clusterState);
service.clusterChanged(clusterChangedEvent);
Expand All @@ -236,7 +238,7 @@ public void test_scaleUp() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false));
listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
Expand All @@ -257,7 +259,7 @@ public void test_scaleUp() {

public void test_scaleDownToZero_whenNoRequests() {
// Initialize the cluster with a deployment with 1 allocation.
ClusterState clusterState = getClusterState(1);
ClusterState clusterState = getClusterState(1, AssignmentState.STARTED);
when(clusterService.state()).thenReturn(clusterState);

AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
Expand All @@ -283,7 +285,7 @@ public void test_scaleDownToZero_whenNoRequests() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

Expand All @@ -299,7 +301,7 @@ public void test_scaleDownToZero_whenNoRequests() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
Expand All @@ -322,9 +324,65 @@ public void test_scaleDownToZero_whenNoRequests() {
service.stop();
}

/**
 * Checks that the scaler service sends no request other than the deployment stats
 * request while the deployment is reported in {@code AssignmentState.STARTING},
 * even though the stats responses carry a high inference count.
 */
public void test_dontScale_whenNotStarted() {
// Initialize the cluster with a deployment with 1 allocation whose assignment is still STARTING.
ClusterState clusterState = getClusterState(1, AssignmentState.STARTING);
when(clusterService.state()).thenReturn(clusterState);

// NOTE(review): the meaning of the trailing arguments (true, 1, 1, 2_000) is not visible here;
// presumably an enabled flag plus timing settings in seconds/milliseconds - confirm against the constructor.
AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
threadPool,
clusterService,
client,
inferenceAuditor,
meterRegistry,
true,
1,
1,
2_000
);
service.start();

// Starting the service must read the cluster state once and register the service as a listener,
// and nothing else. The mocks are then reset so each cycle below is verified in isolation.
verify(clusterService).state();
verify(clusterService).addListener(same(service));
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

// First cycle: many inference requests
// The stubbed stats response reports 1 allocation, 10000 inferences, latency 10.0,
// no recent startup, and the STARTING assignment state.
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 10000, 10.0, false, AssignmentState.STARTING));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

// Give the service time to run; the verifications below expect exactly one stats request in this window.
safeSleep(1200);

// Only the stats request was sent: any further client call (such as a deployment update)
// would fail verifyNoMoreInteractions.
verify(client, times(1)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
verifyNoMoreInteractions(client, clusterService);
reset(client, clusterService);

// Second cycle: again many inference requests
// Same setup as the first cycle, with the reported inference count raised to 20000.
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 20000, 10.0, false, AssignmentState.STARTING));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

safeSleep(1200);

// Again exactly one stats request and no other interaction with the client or cluster service.
verify(client, times(1)).threadPool();
verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
verifyNoMoreInteractions(client, clusterService);
service.stop();
}

public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
// Initialize the cluster with a deployment with 1 allocation.
ClusterState clusterState = getClusterState(1);
ClusterState clusterState = getClusterState(1, AssignmentState.STARTED);
when(clusterService.state()).thenReturn(clusterState);

AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
Expand All @@ -350,7 +408,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true));
listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());

Expand All @@ -366,7 +424,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true));
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
Expand All @@ -388,7 +446,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false, AssignmentState.STARTED));
return Void.TYPE;
}).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
doAnswer(invocationOnMock -> {
Expand Down