Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/133916.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 133916
summary: Fix model assignment error handling and assignment explanation generation
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
0
);

long requiredMemory = mergedPlanBuilder.getDeploymentMemoryRequirement(deploymentNewAllocations, n, newAllocations);
if (newAllocations > 0 && mergedPlanBuilder.canAssign(deploymentNewAllocations, n, newAllocations, requiredMemory)) {
if (newAllocations > 0) {
mergedPlanBuilder.assignModelToNode(deploymentNewAllocations, n, newAllocations);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,28 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
}

public Builder assignModelToNode(Deployment deployment, Node node, int allocations, long requiredMemory) {
if (allocations <= 0) {
if (allocations <= 0 || canAssign(deployment, node, allocations, requiredMemory) == false) {
return this;
}

validateAssignment(deployment, node, allocations);

assignments.get(deployment).compute(node, (n, assignedAllocations) -> assignedAllocations + allocations);
accountMemory(deployment, node, requiredMemory);

if (deployment.priority == Priority.NORMAL) {
remainingNodeCores.compute(node, (n, remCores) -> remCores - allocations * deployment.threadsPerAllocation());
}
remainingModelAllocations.compute(deployment, (m, remModelThreads) -> remModelThreads - allocations);
return this;
}

void validateAssignment(Deployment deployment, Node node, int allocations) {
long requiredMemory = getDeploymentMemoryRequirement(deployment, node, allocations);
validateAssignment(deployment, node, allocations, requiredMemory);
}

private void validateAssignment(Deployment deployment, Node node, int allocations, long requiredMemory) {
if (requiredMemory > remainingNodeMemory.get(node)) {
throw new IllegalArgumentException(
"not enough memory on node ["
Expand All @@ -455,6 +474,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
+ "]"
);
}

if (deployment.priority == Priority.NORMAL && allocations * deployment.threadsPerAllocation() > remainingNodeCores.get(node)) {
throw new IllegalArgumentException(
"not enough cores on node ["
Expand All @@ -468,15 +488,6 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
+ "]"
);
}

assignments.get(deployment).compute(node, (n, assignedAllocations) -> assignedAllocations + allocations);
accountMemory(deployment, node, requiredMemory);

if (deployment.priority == Priority.NORMAL) {
remainingNodeCores.compute(node, (n, remCores) -> remCores - allocations * deployment.threadsPerAllocation());
}
remainingModelAllocations.compute(deployment, (m, remModelThreads) -> remModelThreads - allocations);
return this;
}

private int getAssignedAllocations(Deployment deployment, Node node) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,7 @@ private void unassignOversizedModels(Node n) {
private AssignmentPlan toPlan() {
AssignmentPlan.Builder builder = AssignmentPlan.builder(nodes, deployments);
for (Map.Entry<Tuple<AssignmentPlan.Deployment, Node>, Integer> assignment : tryAssigningRemainingCores().entrySet()) {
// TODO (#101612) The model should be assigned to the node only when it is possible. This means, that canAssign should be
// integrated into the assignModelToNode.
if (builder.canAssign(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue())) {
builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue());
}
builder.assignModelToNode(assignment.getKey().v1(), assignment.getKey().v2(), assignment.getValue());
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels)
remainingZones,
tryAssigningPreviouslyAssignedModels
);

// Update remaining allocations to account for allocations satisfied in this zone
plan.deployments()
.forEach(
d -> deploymentIdToRemainingAllocations.computeIfPresent(
Expand Down Expand Up @@ -217,6 +219,14 @@ private AssignmentPlan swapOriginalDeploymentsInPlan(
return finalPlanBuilder.build();
}

/**
* The mergeAllocationsByNodeIdByDeploymentId method is responsible for consolidating allocation data
* from multiple AssignmentPlan objects into a single structure. This structure maps deployment IDs
* to their respective node allocations, allowing the system to track how resources are distributed
* across nodes for each deployment.
* @param plans List of AssignmentPlan objects to merge allocations from
* @return
*/
private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByDeploymentId(List<AssignmentPlan> plans) {
Map<String, Map<String, Integer>> allocationsByNodeIdByDeploymentId = new HashMap<>();
deployments.forEach(d -> allocationsByNodeIdByDeploymentId.put(d.deploymentId(), new HashMap<>()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ public void testCopyAssignments() {
assertThat(deployment2Assignments.get().get(node2), equalTo(1));
}

public void testRebalance_GivenDeploymentWithMemoryRequirements_ConsidersNativeExecutableOverhead() {
public void testRebalance_GivenDeploymentWithMemoryRequirements_ExplainMissingAllocations() {
// Create a node with just enough memory to fit the model plus native executable overhead
long modelMemory = ByteSizeValue.ofMb(200).getBytes();
long memoryOverhead = ByteSizeValue.ofMb(240).getBytes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ public void testAssignModelToNode_GivenPreviouslyUnassignedModelDoesNotFit() {
Deployment m = new AssignmentPlan.Deployment("m_1", "m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, null, 0, 0);

AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 1));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.validateAssignment(m, n, 1));

assertThat(e.getMessage(), equalTo("not enough memory on node [n_1] to assign [1] allocations to deployment [m_1]"));
}
Expand All @@ -261,7 +261,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() {

AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));

Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 2));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.validateAssignment(m, n, 2));
assertThat(e.getMessage(), containsString("not enough memory on node"));
}
{ // new memory format
Expand All @@ -281,7 +281,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() {

AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));

Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 2));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.validateAssignment(m, n, 2));
assertThat(e.getMessage(), containsString("not enough memory on node"));
}
}
Expand All @@ -291,7 +291,7 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocati
Deployment m = new AssignmentPlan.Deployment("m_1", "m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, null, 0, 0);

AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 5));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.validateAssignment(m, n, 5));

assertThat(
e.getMessage(),
Expand All @@ -315,7 +315,7 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAlloc
);

AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 3));
Exception e = expectThrows(IllegalArgumentException.class, () -> builder.validateAssignment(m, n, 3));

assertThat(
e.getMessage(),
Expand Down