Skip to content

Commit 1f8beed

Browse files
authored
[ML] Make AssignmentPlan consider only assigned allocations (#132170) (#132270)
A follow-up to #131990. This PR ensures that only assigned allocations, and not current allocations, are used in the memory requirements calculation in AssignmentPlan. This change simplifies the code in ZoneAwareAssignmentPlanner and TrainedModelAssignmentRebalancer. This PR also improves readability by adding comments and code documentation, renaming variables, and making the flow of if statements more straightforward. Marking as a non-issue since the bug was already documented in #131990.
1 parent 0ee4c18 commit 1f8beed

File tree

6 files changed

+182
-83
lines changed

6 files changed

+182
-83
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -130,23 +130,15 @@ private static AssignmentPlan mergePlans(
130130
return finalPlanBuilder.build();
131131
}
132132

133-
private static void copyAssignments(
134-
AssignmentPlan source,
135-
AssignmentPlan.Builder dest,
136-
Map<String, AssignmentPlan.Node> originalNodeById
137-
) {
138-
for (AssignmentPlan.Deployment m : source.deployments()) {
139-
Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
140-
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
141-
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
142-
dest.assignModelToNode(m, originalNode, assignment.getValue());
143-
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
144-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
145-
// As the node has all its available memory we need to manually account memory of models with
146-
// current allocations.
147-
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
148-
dest.accountMemory(m, originalNode, requiredMemory);
149-
}
133+
/**
134+
* Transfers assignments from the source AssignmentPlan to the destination AssignmentPlan.Builder.
135+
*/
136+
static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map<String, AssignmentPlan.Node> originalNodeById) {
137+
for (AssignmentPlan.Deployment deployment : source.deployments()) {
138+
Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
139+
for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
140+
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
141+
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
150142
}
151143
}
152144
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -223,23 +223,43 @@ public int compareTo(AssignmentPlan o) {
223223
return Comparator.comparing(AssignmentPlan::computeQuality).compare(this, o);
224224
}
225225

226+
/**
227+
* Checks whether all deployments in the current {@link AssignmentPlan} have at least as many
228+
* allocations as currently assigned.
229+
*/
226230
public boolean satisfiesCurrentAssignments() {
227231
return deployments().stream().allMatch(this::isSatisfyingCurrentAssignmentsForModel);
228232
}
229233

234+
/**
235+
* Checks whether the current assignments for a given {@link Deployment} meet its allocation requirements.
236+
*
237+
* It ensures that the total number of allocations assigned to the deployment across all nodes is
238+
* at least equal to the deployment's current assigned allocations.
239+
*/
230240
private boolean isSatisfyingCurrentAssignmentsForModel(Deployment m) {
231241
if (m.currentAllocationsByNodeId().isEmpty()) {
232242
return true;
233243
}
234244
Map<Node, Integer> nodeAssignments = assignments.get(m);
235-
int currentAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
236-
return currentAllocations >= m.getCurrentAssignedAllocations();
245+
int inPlanAssignedAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
246+
return inPlanAssignedAllocations >= m.getCurrentAssignedAllocations();
237247
}
238248

239-
public boolean satisfiesAllocations(Deployment m) {
240-
return remainingModelAllocations.getOrDefault(m, 0) == 0;
249+
/**
250+
* Checks if the current assignments satisfy the deployment's allocation requirements.
251+
* @param deployment the deployment to check
252+
* @return true if the current assignments satisfy the deployment's allocation requirements, false otherwise
253+
*/
254+
public boolean satisfiesAllocations(Deployment deployment) {
255+
return remainingModelAllocations.getOrDefault(deployment, 0) == 0;
241256
}
242257

258+
/**
259+
* Checks if the current assignments satisfy all deployments' allocation requirements. This means that
260+
* each deployment has no remaining allocations left to assign.
261+
* @return true if the current assignments satisfy the deployments' allocation requirements, false otherwise
262+
*/
243263
public boolean satisfiesAllModels() {
244264
return deployments().stream().allMatch(this::satisfiesAllocations);
245265
}
@@ -424,8 +444,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
424444
if (allocations <= 0) {
425445
return this;
426446
}
427-
if (/*isAlreadyAssigned(deployment, node) == false
428-
&&*/ requiredMemory > remainingNodeMemory.get(node)) {
447+
if (requiredMemory > remainingNodeMemory.get(node)) {
429448
throw new IllegalArgumentException(
430449
"not enough memory on node ["
431450
+ node.id()
@@ -450,7 +469,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
450469
);
451470
}
452471

453-
assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations);
472+
assignments.get(deployment).compute(node, (n, assignedAllocations) -> assignedAllocations + allocations);
454473
accountMemory(deployment, node, requiredMemory);
455474

456475
if (deployment.priority == Priority.NORMAL) {
@@ -461,23 +480,10 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
461480
}
462481

463482
private int getAssignedAllocations(Deployment deployment, Node node) {
464-
int currentAllocations = getCurrentAllocations(deployment, node);
465-
int assignmentAllocations = assignments.get(deployment).get(node);
466-
return currentAllocations + assignmentAllocations;
467-
}
468-
469-
private static int getCurrentAllocations(Deployment m, Node n) {
470-
return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0;
471-
}
472-
473-
public void accountMemory(Deployment m, Node n) {
474-
// TODO (#101612) remove or refactor unused method
475-
long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n));
476-
accountMemory(m, n, requiredMemory);
483+
return assignments.get(deployment).get(node);
477484
}
478485

479486
public void accountMemory(Deployment m, Node n, long requiredMemory) {
480-
// TODO (#101612) computation of required memory should be done internally
481487
remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory);
482488
if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) {
483489
throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.deploymentId() + "]");

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,26 @@ public AssignmentPlan computePlan() {
6060
return computePlan(true);
6161
}
6262

63-
public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
63+
/**
64+
* Computes an {@link AssignmentPlan} for the given nodes and deployments.
65+
* If {@code tryAssigningAllPreviouslyAllocatedModels} is true, then the plan will
66+
* attempt to assign at least one allocation to previously assigned models.
67+
* Otherwise, it will only ensure that deployments assigned to existing nodes preserve at least one allocation.
68+
*
69+
* @param tryAssigningAllPreviouslyAllocatedModels whether to make a best effort to assign previously assigned models somewhere
70+
* with at least one allocation
71+
* @return the computed assignment plan
72+
*/
73+
public AssignmentPlan computePlan(boolean tryAssigningAllPreviouslyAllocatedModels) {
6474
logger.debug(() -> format("Computing plan for nodes = %s; deployments = %s", nodes, deployments));
6575

6676
AssignmentPlan bestPlan;
6777
AssignmentPlan planSatisfyingCurrentAssignments = solveSatisfyingCurrentAssignments();
6878
logger.debug(() -> "Plan satisfying current assignments =\n" + planSatisfyingCurrentAssignments.prettyPrint());
69-
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() == false && tryAssigningPreviouslyAssignedModels) {
79+
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() || tryAssigningAllPreviouslyAllocatedModels == false) {
80+
bestPlan = planSatisfyingCurrentAssignments;
81+
} else {
82+
// try to reuse any deployment that would otherwise drop to zero allocations
7083
AssignmentPlan planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated =
7184
solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated();
7285
logger.debug(
@@ -82,28 +95,37 @@ public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels)
8295
? planSatisfyingCurrentAssignments
8396
: planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated;
8497
}
85-
} else {
86-
bestPlan = planSatisfyingCurrentAssignments;
8798
}
8899

89100
logger.debug(() -> "Best plan =\n" + bestPlan.prettyPrint());
90101
logger.debug(() -> prettyPrintOverallStats(bestPlan));
91102
return bestPlan;
92103
}
93104

105+
/**
106+
* Computes the best assignment plan from two strategies:
107+
* 1. Preserving one allocation on current assignments, which is the most flexible
108+
* 2. Preserving all allocations on current assignments, which is more conservative
109+
* @return the best assignment plan
110+
*/
94111
private AssignmentPlan solveSatisfyingCurrentAssignments() {
95112
AssignmentPlan bestPlan;
96113
// First solve preserving one allocation per assignment because that is most flexible
97114
AssignmentPlan planKeepingOneAllocationOnCurrentAssignments = solveKeepingOneAllocationOnCurrentAssignments();
98-
if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
115+
116+
if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels()) {
117+
// If the plan satisfies all models, then we can use it as is
118+
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
119+
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
120+
// If in the new assignment plan, some deployments have fewer allocations than in the current assignments,
121+
// try explicitly preserving all allocations on current assignments.
99122
bestPlan = solvePreservingAllAllocationsOnCurrentAssignments();
100-
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels() == false) {
123+
} else {
124+
// Choose the best strategy according to {@link AssignmentPlan#computeQuality(AssignmentPlan)}
101125
AssignmentPlan planKeepingAllAllocationsOnCurrentAssignments = solvePreservingAllAllocationsOnCurrentAssignments();
102126
bestPlan = planKeepingAllAllocationsOnCurrentAssignments.compareTo(planKeepingOneAllocationOnCurrentAssignments) >= 0
103127
? planKeepingAllAllocationsOnCurrentAssignments
104128
: planKeepingOneAllocationOnCurrentAssignments;
105-
} else {
106-
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
107129
}
108130
return bestPlan;
109131
}
@@ -120,7 +142,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
120142
1,
121143
m.threadsPerAllocation(),
122144
// don't rely on the current allocation
123-
new HashMap<>(),
145+
Map.of(),
124146
m.maxAssignedAllocations(),
125147
m.getAdaptiveAllocationsSettings(),
126148
m.perDeploymentMemoryBytes(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -182,36 +182,39 @@ private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
182182
List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
183183
AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
184184
plan = preserveAllAllocations.mergePreservedAllocations(plan);
185-
return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
185+
return swapOriginalDeploymentsInPlan(plan, allNodes, modelsAccountingPlans);
186186
}
187187

188-
private AssignmentPlan swapOriginalModelsInPlan(
188+
/**
189+
* The method is responsible for reconstructing an AssignmentPlan
190+
* by replacing the deployments and nodes in the given plan with their original counterparts.
191+
* This ensures that the final plan uses the original objects while preserving the assignments
192+
* and memory accounting from the input plan.
193+
*
194+
* @param plan AssignmentPlan to reconstruct with original models and nodes
195+
* @param allNodes List of all nodes in the system, used to find original nodes
196+
* @param planDeployments List of deployments in the plan, not the original deployments
197+
* @return final plan with original models and nodes swapped in
198+
*/
199+
private AssignmentPlan swapOriginalDeploymentsInPlan(
189200
AssignmentPlan plan,
190201
List<Node> allNodes,
191202
List<AssignmentPlan.Deployment> planDeployments
192203
) {
193-
final Map<String, AssignmentPlan.Deployment> originalModelById = deployments.stream()
204+
final Map<String, AssignmentPlan.Deployment> originalDeploymentsById = deployments.stream()
194205
.collect(Collectors.toMap(AssignmentPlan.Deployment::deploymentId, Function.identity()));
195206
final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
196-
AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, deployments);
197-
for (AssignmentPlan.Deployment m : planDeployments) {
198-
AssignmentPlan.Deployment originalDeployment = originalModelById.get(m.deploymentId());
199-
Map<Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
207+
AssignmentPlan.Builder finalPlanBuilder = AssignmentPlan.builder(allNodes, deployments);
208+
209+
for (AssignmentPlan.Deployment planDeployment : planDeployments) {
210+
AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.deploymentId());
211+
Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
200212
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
201213
Node originalNode = originalNodeById.get(assignment.getKey().id());
202-
planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
203-
if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
204-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
205-
// As the node has all its available memory we need to manually account memory of models with
206-
// current allocations.
207-
long requiredMemory = originalDeployment.estimateMemoryUsageBytes(
208-
originalDeployment.currentAllocationsByNodeId().get(originalNode.id())
209-
);
210-
planBuilder.accountMemory(m, originalNode, requiredMemory);
211-
}
214+
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
212215
}
213216
}
214-
return planBuilder.build();
217+
return finalPlanBuilder.build();
215218
}
216219

217220
private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByDeploymentId(List<AssignmentPlan> plans) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@
2121
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
2222
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
2323
import org.elasticsearch.xpack.ml.MachineLearning;
24+
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
2425
import org.elasticsearch.xpack.ml.job.NodeLoad;
2526

2627
import java.util.ArrayList;
2728
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Optional;
32+
import java.util.function.Function;
33+
import java.util.stream.Collectors;
3134

3235
import static org.hamcrest.Matchers.aMapWithSize;
3336
import static org.hamcrest.Matchers.anEmptyMap;
@@ -1127,6 +1130,74 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() {
11271130
assertThat(assignment.getReason().isPresent(), is(false));
11281131
}
11291132

1133+
public void testCopyAssignments() {
1134+
// Create test nodes
1135+
AssignmentPlan.Node node1 = new AssignmentPlan.Node("node-1", ByteSizeValue.ofGb(1).getBytes(), 4);
1136+
AssignmentPlan.Node node2 = new AssignmentPlan.Node("node-2", ByteSizeValue.ofGb(1).getBytes(), 8);
1137+
List<AssignmentPlan.Node> nodes = List.of(node1, node2);
1138+
1139+
// Create test deployments
1140+
AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
1141+
"deployment-1",
1142+
"model-1",
1143+
ByteSizeValue.ofMb(100).getBytes(),
1144+
2,
1145+
1,
1146+
Map.of(),
1147+
0,
1148+
null,
1149+
Priority.NORMAL,
1150+
0,
1151+
0
1152+
);
1153+
AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
1154+
"deployment-2",
1155+
"model-2",
1156+
ByteSizeValue.ofMb(100).getBytes(),
1157+
1,
1158+
2,
1159+
Map.of(),
1160+
0,
1161+
null,
1162+
Priority.LOW,
1163+
0,
1164+
0
1165+
);
1166+
List<AssignmentPlan.Deployment> deployments = List.of(deployment1, deployment2);
1167+
1168+
// Create source plan and assign models to nodes
1169+
AssignmentPlan.Builder sourceBuilder = AssignmentPlan.builder(nodes, deployments);
1170+
sourceBuilder.assignModelToNode(deployment1, node1, 1);
1171+
sourceBuilder.assignModelToNode(deployment1, node2, 1);
1172+
sourceBuilder.assignModelToNode(deployment2, node2, 1);
1173+
AssignmentPlan source = sourceBuilder.build();
1174+
1175+
// Create destination plan
1176+
AssignmentPlan.Builder dest = AssignmentPlan.builder(nodes, deployments);
1177+
1178+
// Create map of node IDs to original nodes
1179+
Map<String, AssignmentPlan.Node> originalNodeById = nodes.stream()
1180+
.collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));
1181+
1182+
// Call copyAssignments
1183+
TrainedModelAssignmentRebalancer.copyAssignments(source, dest, originalNodeById);
1184+
1185+
// Build the destination plan
1186+
AssignmentPlan result = dest.build();
1187+
1188+
// Verify assignments
1189+
Optional<Map<AssignmentPlan.Node, Integer>> deployment1Assignments = result.assignments(deployment1);
1190+
assertThat(deployment1Assignments.isPresent(), is(true));
1191+
assertThat(deployment1Assignments.get().size(), equalTo(2));
1192+
assertThat(deployment1Assignments.get().get(node1), equalTo(1));
1193+
assertThat(deployment1Assignments.get().get(node2), equalTo(1));
1194+
1195+
Optional<Map<AssignmentPlan.Node, Integer>> deployment2Assignments = result.assignments(deployment2);
1196+
assertThat(deployment2Assignments.isPresent(), is(true));
1197+
assertThat(deployment2Assignments.get().size(), equalTo(1));
1198+
assertThat(deployment2Assignments.get().get(node2), equalTo(1));
1199+
}
1200+
11301201
private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) {
11311202
return lowPriorityParams(deploymentId, deploymentId, modelSize);
11321203
}

0 commit comments

Comments
 (0)