Skip to content

Commit 727d5e7

Browse files
authored
[8.18] [ML] Make AssignmentPlan consider only assigned allocations (#132170) (#132275)
* [ML] Make AssignmentPlan consider only assigned allocations (#132170)
  A follow-up to #131990. This PR ensures that only assigned allocations, and not current allocations, are used in the memory requirements calculation in AssignmentPlan. This change simplified the code in ZoneAwareAssignmentPlanner and TrainedModelAssignmentRebalancer. This PR also improves readability by adding comments and code documentation, renaming variables, and making the flow of if statements more straightforward. Marking as a non-issue since the bug was already documented in #131990. (cherry picked from commit 80c47f3)
* Fix backport errors
* Fix unit test
1 parent 6c7d0a9 commit 727d5e7

File tree

6 files changed

+182
-84
lines changed

6 files changed

+182
-84
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -134,23 +134,15 @@ private static AssignmentPlan mergePlans(
134134
return finalPlanBuilder.build();
135135
}
136136

137-
private static void copyAssignments(
138-
AssignmentPlan source,
139-
AssignmentPlan.Builder dest,
140-
Map<String, AssignmentPlan.Node> originalNodeById
141-
) {
142-
for (AssignmentPlan.Deployment m : source.models()) {
143-
Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
144-
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
145-
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
146-
dest.assignModelToNode(m, originalNode, assignment.getValue());
147-
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
148-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
149-
// As the node has all its available memory we need to manually account memory of models with
150-
// current allocations.
151-
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
152-
dest.accountMemory(m, originalNode, requiredMemory);
153-
}
137+
/**
138+
* Transfers assignments from the source AssignmentPlan to the destination AssignmentPlan.Builder.
139+
*/
140+
static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map<String, AssignmentPlan.Node> originalNodeById) {
141+
for (AssignmentPlan.Deployment deployment : source.models()) {
142+
Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
143+
for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
144+
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
145+
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
154146
}
155147
}
156148
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -207,23 +207,43 @@ public int compareTo(AssignmentPlan o) {
207207
return Comparator.comparing(AssignmentPlan::computeQuality).compare(this, o);
208208
}
209209

210+
/**
211+
* Checks whether all deployments in the current {@link AssignmentPlan} have at least as many
212+
* allocations as currently assigned.
213+
*/
210214
public boolean satisfiesCurrentAssignments() {
211215
return models().stream().allMatch(this::isSatisfyingCurrentAssignmentsForModel);
212216
}
213217

218+
/**
219+
* Checks whether the current assignments for a given {@link Deployment} meet its allocation requirements.
220+
*
221+
* It ensures that the total number of allocations assigned to the deployment across all nodes is
222+
* at least equal to the deployment's current assigned allocations.
223+
*/
214224
private boolean isSatisfyingCurrentAssignmentsForModel(Deployment m) {
215225
if (m.currentAllocationsByNodeId().isEmpty()) {
216226
return true;
217227
}
218228
Map<Node, Integer> nodeAssignments = assignments.get(m);
219-
int currentAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
220-
return currentAllocations >= m.getCurrentAssignedAllocations();
229+
int inPlanAssignedAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
230+
return inPlanAssignedAllocations >= m.getCurrentAssignedAllocations();
221231
}
222232

223-
public boolean satisfiesAllocations(Deployment m) {
224-
return remainingModelAllocations.getOrDefault(m, 0) == 0;
233+
/**
234+
* Checks if the current assignments satisfy the deployment's allocation requirements.
235+
* @param deployment the deployment to check
236+
* @return true if the current assignments satisfy the deployment's allocation requirements, false otherwise
237+
*/
238+
public boolean satisfiesAllocations(Deployment deployment) {
239+
return remainingModelAllocations.getOrDefault(deployment, 0) == 0;
225240
}
226241

242+
/**
243+
* Checks if the current assignments satisfy all deployments' allocation requirements. This means that
244+
* each deployment has no remaining allocations left to assign.
245+
* @return true if the current assignments satisfy the deployments' allocation requirements, false otherwise
246+
*/
227247
public boolean satisfiesAllModels() {
228248
return models().stream().allMatch(this::satisfiesAllocations);
229249
}
@@ -408,8 +428,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
408428
if (allocations <= 0) {
409429
return this;
410430
}
411-
if (/*isAlreadyAssigned(deployment, node) == false
412-
&&*/ requiredMemory > remainingNodeMemory.get(node)) {
431+
if (requiredMemory > remainingNodeMemory.get(node)) {
413432
throw new IllegalArgumentException(
414433
"not enough memory on node ["
415434
+ node.id()
@@ -434,7 +453,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
434453
);
435454
}
436455

437-
assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations);
456+
assignments.get(deployment).compute(node, (n, assignedAllocations) -> assignedAllocations + allocations);
438457
accountMemory(deployment, node, requiredMemory);
439458

440459
if (deployment.priority == Priority.NORMAL) {
@@ -445,23 +464,10 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
445464
}
446465

447466
private int getAssignedAllocations(Deployment deployment, Node node) {
448-
int currentAllocations = getCurrentAllocations(deployment, node);
449-
int assignmentAllocations = assignments.get(deployment).get(node);
450-
return currentAllocations + assignmentAllocations;
451-
}
452-
453-
private static int getCurrentAllocations(Deployment m, Node n) {
454-
return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0;
455-
}
456-
457-
public void accountMemory(Deployment m, Node n) {
458-
// TODO (#101612) remove or refactor unused method
459-
long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n));
460-
accountMemory(m, n, requiredMemory);
467+
return assignments.get(deployment).get(node);
461468
}
462469

463-
public void accountMemory(Deployment m, Node n, long requiredMemory) {
464-
// TODO (#101612) computation of required memory should be done internally
470+
void accountMemory(Deployment m, Node n, long requiredMemory) {
465471
remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory);
466472
if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) {
467473
throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]");

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,26 @@ public AssignmentPlan computePlan() {
5757
return computePlan(true);
5858
}
5959

60-
public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
60+
/**
61+
* Computes an {@link AssignmentPlan} for the given nodes and deployments.
62+
* If {@code tryAssigningAllPreviouslyAllocatedModels} is true, then the plan will
63+
* attempt to assign at least one allocation to previously assigned models.
64+
* Otherwise, it will only ensure that deployments assigned to existing nodes preserve at least one allocation.
65+
*
66+
* @param tryAssigningAllPreviouslyAllocatedModels whether to make a best-effort attempt to assign previously assigned models somewhere
67+
* with at least one allocation
68+
* @return the computed assignment plan
69+
*/
70+
public AssignmentPlan computePlan(boolean tryAssigningAllPreviouslyAllocatedModels) {
6171
logger.debug(() -> format("Computing plan for nodes = %s; deployments = %s", nodes, deployments));
6272

6373
AssignmentPlan bestPlan;
6474
AssignmentPlan planSatisfyingCurrentAssignments = solveSatisfyingCurrentAssignments();
6575
logger.debug(() -> "Plan satisfying current assignments =\n" + planSatisfyingCurrentAssignments.prettyPrint());
66-
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() == false && tryAssigningPreviouslyAssignedModels) {
76+
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() || tryAssigningAllPreviouslyAllocatedModels == false) {
77+
bestPlan = planSatisfyingCurrentAssignments;
78+
} else {
79+
// try to reuse any deployment that would otherwise drop to zero allocations
6780
AssignmentPlan planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated =
6881
solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated();
6982
logger.debug(
@@ -79,28 +92,37 @@ public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels)
7992
? planSatisfyingCurrentAssignments
8093
: planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated;
8194
}
82-
} else {
83-
bestPlan = planSatisfyingCurrentAssignments;
8495
}
8596

8697
logger.debug(() -> "Best plan =\n" + bestPlan.prettyPrint());
8798
logger.debug(() -> prettyPrintOverallStats(bestPlan));
8899
return bestPlan;
89100
}
90101

102+
/**
103+
* Computes the best assignment plan from two strategies:
104+
* 1. Preserving one allocation on current assignments, which is the most flexible
105+
* 2. Preserving all allocations on current assignments, which is more conservative
106+
* @return the best assignment plan
107+
*/
91108
private AssignmentPlan solveSatisfyingCurrentAssignments() {
92109
AssignmentPlan bestPlan;
93110
// First solve preserving one allocation per assignment because that is most flexible
94111
AssignmentPlan planKeepingOneAllocationOnCurrentAssignments = solveKeepingOneAllocationOnCurrentAssignments();
95-
if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
112+
113+
if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels()) {
114+
// If the plan satisfies all models, then we can use it as is
115+
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
116+
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
117+
// If in the new assignment plan, some deployments have fewer allocations than in the current assignments,
118+
// try explicitly preserving all allocations on current assignments.
96119
bestPlan = solvePreservingAllAllocationsOnCurrentAssignments();
97-
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels() == false) {
120+
} else {
121+
// Choose the best strategy according to {@link AssignmentPlan#computeQuality(AssignmentPlan)}
98122
AssignmentPlan planKeepingAllAllocationsOnCurrentAssignments = solvePreservingAllAllocationsOnCurrentAssignments();
99123
bestPlan = planKeepingAllAllocationsOnCurrentAssignments.compareTo(planKeepingOneAllocationOnCurrentAssignments) >= 0
100124
? planKeepingAllAllocationsOnCurrentAssignments
101125
: planKeepingOneAllocationOnCurrentAssignments;
102-
} else {
103-
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
104126
}
105127
return bestPlan;
106128
}
@@ -116,7 +138,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
116138
1,
117139
m.threadsPerAllocation(),
118140
// don't rely on the current allocation
119-
new HashMap<>(),
141+
Map.of(),
120142
m.maxAssignedAllocations(),
121143
m.getAdaptiveAllocationsSettings(),
122144
m.perDeploymentMemoryBytes(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -167,36 +167,39 @@ private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
167167
List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
168168
AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
169169
plan = preserveAllAllocations.mergePreservedAllocations(plan);
170-
return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
170+
return swapOriginalDeploymentsInPlan(plan, allNodes, modelsAccountingPlans);
171171
}
172172

173-
private AssignmentPlan swapOriginalModelsInPlan(
173+
/**
174+
* The method is responsible for reconstructing an AssignmentPlan
175+
* by replacing the deployments and nodes in the given plan with their original counterparts.
176+
* This ensures that the final plan uses the original objects while preserving the assignments
177+
* and memory accounting from the input plan.
178+
*
179+
* @param plan AssignmentPlan to reconstruct with original models and nodes
180+
* @param allNodes List of all nodes in the system, used to find original nodes
181+
* @param planDeployments List of deployments in the plan, not the original deployments
182+
* @return final plan with original models and nodes swapped in
183+
*/
184+
private AssignmentPlan swapOriginalDeploymentsInPlan(
174185
AssignmentPlan plan,
175186
List<Node> allNodes,
176187
List<AssignmentPlan.Deployment> planDeployments
177188
) {
178-
final Map<String, AssignmentPlan.Deployment> originalModelById = deployments.stream()
189+
final Map<String, AssignmentPlan.Deployment> originalDeploymentsById = deployments.stream()
179190
.collect(Collectors.toMap(AssignmentPlan.Deployment::id, Function.identity()));
180191
final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
181-
AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, deployments);
182-
for (AssignmentPlan.Deployment m : planDeployments) {
183-
AssignmentPlan.Deployment originalDeployment = originalModelById.get(m.id());
184-
Map<Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
192+
AssignmentPlan.Builder finalPlanBuilder = AssignmentPlan.builder(allNodes, deployments);
193+
194+
for (AssignmentPlan.Deployment planDeployment : planDeployments) {
195+
AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.id());
196+
Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
185197
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
186198
Node originalNode = originalNodeById.get(assignment.getKey().id());
187-
planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
188-
if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
189-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
190-
// As the node has all its available memory we need to manually account memory of models with
191-
// current allocations.
192-
long requiredMemory = originalDeployment.estimateMemoryUsageBytes(
193-
originalDeployment.currentAllocationsByNodeId().get(originalNode.id())
194-
);
195-
planBuilder.accountMemory(m, originalNode, requiredMemory);
196-
}
199+
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
197200
}
198201
}
199-
return planBuilder.build();
202+
return finalPlanBuilder.build();
200203
}
201204

202205
private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByModelId(List<AssignmentPlan> plans) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@
2121
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
2222
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
2323
import org.elasticsearch.xpack.ml.MachineLearning;
24+
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
2425
import org.elasticsearch.xpack.ml.job.NodeLoad;
2526

2627
import java.util.ArrayList;
2728
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Optional;
32+
import java.util.function.Function;
33+
import java.util.stream.Collectors;
3134

3235
import static org.hamcrest.Matchers.aMapWithSize;
3336
import static org.hamcrest.Matchers.anEmptyMap;
@@ -1153,6 +1156,73 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() {
11531156
assertThat(assignment.getReason().isPresent(), is(false));
11541157
}
11551158

1159+
public void testCopyAssignments() {
1160+
// Create test nodes
1161+
AssignmentPlan.Node node1 = new AssignmentPlan.Node("node-1", ByteSizeValue.ofGb(1).getBytes(), 4);
1162+
AssignmentPlan.Node node2 = new AssignmentPlan.Node("node-2", ByteSizeValue.ofGb(1).getBytes(), 8);
1163+
List<AssignmentPlan.Node> nodes = List.of(node1, node2);
1164+
1165+
// Create test deployments
1166+
AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
1167+
"model-1",
1168+
ByteSizeValue.ofMb(100).getBytes(),
1169+
2,
1170+
1,
1171+
Map.of(),
1172+
0,
1173+
null,
1174+
Priority.NORMAL,
1175+
0,
1176+
0
1177+
);
1178+
AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
1179+
1180+
"model-2",
1181+
ByteSizeValue.ofMb(100).getBytes(),
1182+
1,
1183+
2,
1184+
Map.of(),
1185+
0,
1186+
null,
1187+
Priority.LOW,
1188+
0,
1189+
0
1190+
);
1191+
List<AssignmentPlan.Deployment> deployments = List.of(deployment1, deployment2);
1192+
1193+
// Create source plan and assign models to nodes
1194+
AssignmentPlan.Builder sourceBuilder = AssignmentPlan.builder(nodes, deployments);
1195+
sourceBuilder.assignModelToNode(deployment1, node1, 1);
1196+
sourceBuilder.assignModelToNode(deployment1, node2, 1);
1197+
sourceBuilder.assignModelToNode(deployment2, node2, 1);
1198+
AssignmentPlan source = sourceBuilder.build();
1199+
1200+
// Create destination plan
1201+
AssignmentPlan.Builder dest = AssignmentPlan.builder(nodes, deployments);
1202+
1203+
// Create map of node IDs to original nodes
1204+
Map<String, AssignmentPlan.Node> originalNodeById = nodes.stream()
1205+
.collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));
1206+
1207+
// Call copyAssignments
1208+
TrainedModelAssignmentRebalancer.copyAssignments(source, dest, originalNodeById);
1209+
1210+
// Build the destination plan
1211+
AssignmentPlan result = dest.build();
1212+
1213+
// Verify assignments
1214+
Optional<Map<AssignmentPlan.Node, Integer>> deployment1Assignments = result.assignments(deployment1);
1215+
assertThat(deployment1Assignments.isPresent(), is(true));
1216+
assertThat(deployment1Assignments.get().size(), equalTo(2));
1217+
assertThat(deployment1Assignments.get().get(node1), equalTo(1));
1218+
assertThat(deployment1Assignments.get().get(node2), equalTo(1));
1219+
1220+
Optional<Map<AssignmentPlan.Node, Integer>> deployment2Assignments = result.assignments(deployment2);
1221+
assertThat(deployment2Assignments.isPresent(), is(true));
1222+
assertThat(deployment2Assignments.get().size(), equalTo(1));
1223+
assertThat(deployment2Assignments.get().get(node2), equalTo(1));
1224+
}
1225+
11561226
private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) {
11571227
return lowPriorityParams(deploymentId, deploymentId, modelSize);
11581228
}

0 commit comments

Comments
 (0)