Skip to content

Commit d537939

Browse files
committed
fixed tests in TrainedModelAssignmentRebalancerTests and ZoneAwareAssignmentPlannerTests
1 parent a8cc385 commit d537939

File tree

2 files changed

+49
-30
lines changed

2 files changed

+49
-30
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,23 +130,30 @@ private static AssignmentPlan mergePlans(
130130
return finalPlanBuilder.build();
131131
}
132132

133+
/**
134+
* Transfers assignments from the source AssignmentPlan to the destination AssignmentPlan.Builder.
135+
*
136+
* @param source the plan whose deployments and per-node assignments are copied
137+
* @param dest the plan builder that receives each copied assignment
138+
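* @param originalNodeById nodes keyed by node id; each node in a source assignment is replaced by the node registered here under the same id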
139+
*/
133140
private static void copyAssignments(
134141
AssignmentPlan source,
135142
AssignmentPlan.Builder dest,
136143
Map<String, AssignmentPlan.Node> originalNodeById
137144
) {
138-
for (AssignmentPlan.Deployment m : source.deployments()) {
139-
Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
140-
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
141-
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
142-
dest.assignModelToNode(m, originalNode, assignment.getValue());
143-
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
144-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
145-
// As the node has all its available memory we need to manually account memory of models with
146-
// current allocations.
147-
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
148-
dest.accountMemory(m, originalNode, requiredMemory);
149-
}
145+
for (AssignmentPlan.Deployment deployment : source.deployments()) {
146+
Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
147+
for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
148+
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
149+
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
150+
// if (deployment.currentAllocationsByNodeId().containsKey(node.id())) {
151+
// // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
152+
// // As the node has all its available memory we need to manually account memory of models with
153+
// // current allocations.
154+
// long requiredMemory = deployment.estimateMemoryUsageBytes(deployment.currentAllocationsByNodeId().get(node.id()));
155+
// dest.accountMemory(deployment, node, requiredMemory);
156+
// }
150157
}
151158
}
152159
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -182,36 +182,48 @@ private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
182182
List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
183183
AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
184184
plan = preserveAllAllocations.mergePreservedAllocations(plan);
185-
return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
185+
return swapOriginalDeploymentsInPlan(plan, allNodes, modelsAccountingPlans);
186186
}
187187

188-
private AssignmentPlan swapOriginalModelsInPlan(
188+
/**
189+
* Reconstructs an AssignmentPlan
190+
* by replacing the deployments and nodes in the given plan with their original counterparts.
191+
* This ensures that the final plan uses the original objects while preserving the assignments
192+
* and memory accounting from the input plan.
193+
*
194+
* @param plan AssignmentPlan to reconstruct with original models and nodes
195+
* @param allNodes List of all nodes in the system, used to find original nodes
196+
* @param planDeployments List of deployments in the plan, not the original deployments
197+
* @return final plan with original models and nodes swapped in
198+
*/
199+
private AssignmentPlan swapOriginalDeploymentsInPlan(
189200
AssignmentPlan plan,
190201
List<Node> allNodes,
191202
List<AssignmentPlan.Deployment> planDeployments
192203
) {
193-
final Map<String, AssignmentPlan.Deployment> originalModelById = deployments.stream()
204+
final Map<String, AssignmentPlan.Deployment> originalDeploymentsById = deployments.stream()
194205
.collect(Collectors.toMap(AssignmentPlan.Deployment::deploymentId, Function.identity()));
195206
final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
196-
AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, deployments);
197-
for (AssignmentPlan.Deployment m : planDeployments) {
198-
AssignmentPlan.Deployment originalDeployment = originalModelById.get(m.deploymentId());
199-
Map<Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
207+
AssignmentPlan.Builder finalPlanBuilder = AssignmentPlan.builder(allNodes, deployments);
208+
209+
for (AssignmentPlan.Deployment planDeployment : planDeployments) {
210+
AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.deploymentId());
211+
Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
200212
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
201213
Node originalNode = originalNodeById.get(assignment.getKey().id());
202-
planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
203-
if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
204-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
205-
// As the node has all its available memory we need to manually account memory of models with
206-
// current allocations.
207-
long requiredMemory = originalDeployment.estimateMemoryUsageBytes(
208-
originalDeployment.currentAllocationsByNodeId().get(originalNode.id())
209-
);
210-
planBuilder.accountMemory(m, originalNode, requiredMemory);
211-
}
214+
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
215+
// if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
216+
// // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
217+
// // As the node has all its available memory we need to manually account memory of models with
218+
// // current allocations.
219+
// long requiredMemory = originalDeployment.estimateMemoryUsageBytes(
220+
// originalDeployment.currentAllocationsByNodeId().get(originalNode.id())
221+
// );
222+
// finalPlanBuilder.accountMemory(planDeployment, originalNode, requiredMemory);
223+
// }
212224
}
213225
}
214-
return planBuilder.build();
226+
return finalPlanBuilder.build();
215227
}
216228

217229
private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByDeploymentId(List<AssignmentPlan> plans) {

0 commit comments

Comments
 (0)