Merged
6 changes: 6 additions & 0 deletions docs/changelog/133930.yaml
@@ -0,0 +1,6 @@
pr: 133930
summary: Improve memory estimation methods accuracy in `TrainedModelAssignmentRebalancer`
and related classes
area: Machine Learning
type: bug
issues: []
@@ -406,7 +406,28 @@ private Optional<String> explainAssignment(
return Optional.of(load.getError());
}

if (deployment.memoryBytes() > assignmentPlan.getRemainingNodeMemory(node.getId())) {
// Find how many allocations already exist on this node.
// We must match by node ID because assignmentPlan.assignments() returns a map
// keyed by AssignmentPlan.Node, while the node argument is a DiscoveryNode.
int existingAllocationsOnNode = assignmentPlan.assignments(deployment)
.map(
assignments -> assignments.getOrDefault(
assignments.keySet().stream().filter(n -> n.id().equals(node.getId())).findFirst().orElse(null),
Contributor:
`assignments` is a `Map`, right?

So why not do `assignments.getOrDefault(node, 0)` instead of streaming/filtering the key set?

Contributor (Author):

`assignments` is a `Map<AssignmentPlan.Node, Integer>`, while `node` is of type `DiscoveryNode`. That's why I need to compare their IDs.

Contributor:

OK, thanks.

Just thinking out loud: shouldn't the return value of assignmentPlan.assignments be a Map<String, Integer> instead (the string being the node ID)? That sounds more useful. Is that a big refactoring?

Contributor (Author):

`AssignmentPlan.assignments(deployment)` is used in 10 places in the main code and in 100 places in the test code. We can check whether we can refactor it, but that should happen in a separate PR.

Contributor:

OK, I agree with that. Then please add a comment here about the Node vs DiscoveryNode mismatch, noting that it could benefit from refactoring (to key on the string node ID), and it LGTM.

Contributor (Author):

I created #134030 so it won't get lost.
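The lookup under discussion can be sketched in isolation. This is a minimal, self-contained illustration (the `PlanNode` record and both helper methods are hypothetical stand-ins, not the real `AssignmentPlan` API): the map is keyed by the plan's own node type, so with only a `DiscoveryNode` ID in hand the key set must be searched by ID, whereas the proposed refactor to key by the string node ID would reduce it to a plain `getOrDefault`.

```java
import java.util.HashMap;
import java.util.Map;

public class NodeLookupSketch {
    // Hypothetical stand-in for AssignmentPlan.Node, for illustration only.
    record PlanNode(String id) {}

    // Current approach: the map is keyed by PlanNode but we only hold a node ID,
    // so we locate the matching key via a stream filter before the lookup.
    static int allocationsOnNode(Map<PlanNode, Integer> assignments, String nodeId) {
        return assignments.getOrDefault(
            assignments.keySet().stream().filter(n -> n.id().equals(nodeId)).findFirst().orElse(null),
            0
        );
    }

    // Proposed refactor (#134030): key the map by the string node ID,
    // turning the linear scan into a direct map lookup.
    static int allocationsOnNodeById(Map<String, Integer> assignmentsById, String nodeId) {
        return assignmentsById.getOrDefault(nodeId, 0);
    }

    public static void main(String[] args) {
        Map<PlanNode, Integer> byNode = new HashMap<>();
        byNode.put(new PlanNode("node-1"), 2);
        System.out.println(allocationsOnNode(byNode, "node-1"));  // prints 2
        System.out.println(allocationsOnNode(byNode, "node-9"));  // prints 0

        Map<String, Integer> byId = new HashMap<>();
        byId.put("node-1", 2);
        System.out.println(allocationsOnNodeById(byId, "node-1")); // prints 2
    }
}
```

Note the sketch uses `HashMap` deliberately: the stream-filter branch can produce a `null` key for an absent node, which `HashMap.getOrDefault` tolerates but the immutable `Map.of` maps reject with a `NullPointerException`.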

Member:

Another refactor to consider is making the `explainAssignment()` function part of the `AssignmentPlan` class. The code here is trying to reverse engineer the planner's decision making, and it's easy for the two to get out of sync.

0
)
)
.orElse(0);

// Calculate how many allocations remain to be assigned
int unassignedAllocations = deployment.allocations() - assignmentPlan.totalAllocations(deployment);

// Check if there's enough memory for additional allocations
long additionalMemory = deployment.estimateAdditionalMemoryUsageBytes(
existingAllocationsOnNode,
existingAllocationsOnNode + unassignedAllocations
);
long availableMemory = assignmentPlan.getRemainingNodeMemory(node.getId());
if (additionalMemory > availableMemory) {
// If any ML processes are running on a node we require some space to load the shared libraries.
// So if none are currently running then this per-node overhead must be added to the requirement.
// From node load we know if we had any jobs or models assigned before the rebalance.
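The new check in this hunk replaces a comparison of the deployment's full memory footprint with a comparison of only the *additional* memory needed to host the still-unassigned allocations, given how many are already on the node. A minimal sketch of that logic, with a hypothetical flat per-allocation cost standing in for the real `estimateAdditionalMemoryUsageBytes` formula:

```java
public class MemoryCheckSketch {
    // Hypothetical stand-in for Deployment.estimateAdditionalMemoryUsageBytes:
    // a flat per-allocation cost, purely for illustration. The real estimate
    // lives in StartTrainedModelDeploymentAction.estimateMemoryUsageBytes.
    static long estimateAdditionalMemoryUsageBytes(long perAllocationBytes, int allocationsOld, int allocationsNew) {
        return (allocationsNew - allocationsOld) * perAllocationBytes;
    }

    // Mirror of the hunk's shape: estimate the memory delta from the existing
    // allocations on the node to also hosting the unassigned ones, then compare
    // that delta against the node's remaining memory.
    static boolean fitsOnNode(long perAllocationBytes, int existingAllocationsOnNode, int unassignedAllocations, long remainingNodeMemory) {
        long additionalMemory = estimateAdditionalMemoryUsageBytes(
            perAllocationBytes,
            existingAllocationsOnNode,
            existingAllocationsOnNode + unassignedAllocations
        );
        return additionalMemory <= remainingNodeMemory;
    }

    public static void main(String[] args) {
        // 2 unassigned allocations at 100 bytes each need 200 bytes; 250 remain.
        System.out.println(fitsOnNode(100L, 1, 2, 250L)); // prints true
        // 3 unassigned allocations need 300 bytes; only 250 remain.
        System.out.println(fitsOnNode(100L, 1, 3, 250L)); // prints false
    }
}
```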
@@ -107,7 +107,7 @@ public long estimateMemoryUsageBytes(int allocations) {
);
}

long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
public long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
modelId,
memoryBytes,
@@ -308,7 +308,7 @@ private Quality computeQuality() {
Node n = nodeAllocations.getKey();
weighedAllocationsScore += (1 + 0.1 * (m.currentAllocationsByNodeId().containsKey(n.id()) ? 1 : 0)) * modelAssignments
.get(n);
memoryScore -= (nodeAllocations.getValue() > 0 ? m.memoryBytes() : 0);
memoryScore -= (nodeAllocations.getValue() > 0 ? m.estimateMemoryUsageBytes(nodeAllocations.getValue()) : 0);
Member:

`AssignmentPlan.Deployment::memoryBytes()` is trappy, as `estimateMemoryUsageBytes()` should always be used instead.

Because `AssignmentPlan.Deployment` is a record, it will always have a public accessor for the `memoryBytes` field. The only way I can think of to stop people using it is to override the accessor:

        @Override
        public long memoryBytes() {
            throw new UnsupportedOperationException("use estimateMemoryUsageBytes(int allocations) instead");
        }
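The suggestion above can be demonstrated standalone. This is a simplified sketch (the `Deployment` record here and its estimate formula are hypothetical, not the real `AssignmentPlan.Deployment`): a record always generates a public accessor per component, and overriding that accessor to throw is what steers callers toward the estimate method, while code inside the record body can still read the field directly.

```java
public class DeploymentSketch {
    record Deployment(String modelId, long memoryBytes) {
        // Override the generated accessor so callers cannot use the raw
        // field value, which undercounts the true memory requirement.
        @Override
        public long memoryBytes() {
            throw new UnsupportedOperationException("use estimateMemoryUsageBytes(int allocations) instead");
        }

        // Hypothetical per-allocation formula for illustration; inside the
        // record body the memoryBytes field is still directly accessible.
        public long estimateMemoryUsageBytes(int allocations) {
            return memoryBytes + allocations * 10L;
        }
    }

    public static void main(String[] args) {
        Deployment d = new Deployment("model-1", 100L);
        System.out.println(d.estimateMemoryUsageBytes(2)); // prints 120
        try {
            d.memoryBytes();
        } catch (UnsupportedOperationException e) {
            System.out.println("blocked: " + e.getMessage());
        }
    }
}
```

One caveat worth checking in review: the record's generated `equals`/`hashCode` read the backing fields directly, so overriding the accessor does not break them.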

}
}
}
@@ -81,7 +81,7 @@ class LinearProgrammingPlanSolver {
long maxNodeMemory = nodes.stream().map(Node::availableMemoryBytes).max(Long::compareTo).orElse(0L);
this.deployments = deployments.stream()
// Filter out models that are not already assigned and do not fit on any node
.filter(m -> m.currentAllocationsByNodeId().isEmpty() == false || m.memoryBytes() <= maxNodeMemory)
.filter(m -> m.currentAllocationsByNodeId().isEmpty() == false || m.minimumMemoryRequiredBytes() <= maxNodeMemory)
// Also filter out models whose threads per allocation are more than the max node cores
.filter(m -> m.threadsPerAllocation() <= maxNodeCores)
.toList();
@@ -1198,6 +1198,64 @@ public void testCopyAssignments() {
assertThat(deployment2Assignments.get().get(node2), equalTo(1));
}

public void testRebalance_GivenDeploymentWithMemoryRequirements_ConsidersNativeExecutableOverhead() {
// Create a node with just enough memory to fit the model plus native executable overhead
long modelMemory = ByteSizeValue.ofMb(200).getBytes();
long memoryOverhead = ByteSizeValue.ofMb(240).getBytes();
long nodeMemory = memoryOverhead + modelMemory * 2 + MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes();

DiscoveryNode node = buildNode("node-1", nodeMemory, 4);

String deploymentId = "model-with-overhead-test";
StartTrainedModelDeploymentAction.TaskParams taskParams = normalPriorityParams(deploymentId, deploymentId, modelMemory, 1, 1);

TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build();
Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();

// This node has no jobs or models yet, so the overhead should be accounted for
nodeLoads.put(node, NodeLoad.builder("node-1").setMaxMemory(nodeMemory).build());

TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
currentMetadata,
nodeLoads,
Map.of(List.of(), List.of(node)),
Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)),
1
).rebalance().build();

// Verify the deployment was successful
TrainedModelAssignment assignment = result.getDeploymentAssignment(deploymentId);
assertThat(assignment, is(notNullValue()));
assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING));
assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1)));
assertThat(assignment.getNodeRoutingTable(), hasKey("node-1"));
assertThat(assignment.getReason().isPresent(), is(false));

// Now try with a node that has slightly less memory - this should fail
long insufficientNodeMemory = nodeMemory - ByteSizeValue.ofMb(21).getBytes();
DiscoveryNode insufficientNode = buildNode("node-2", insufficientNodeMemory, 4);

Map<DiscoveryNode, NodeLoad> insufficientNodeLoads = Map.of(
insufficientNode,
NodeLoad.builder("node-2").setMaxMemory(insufficientNodeMemory).build()
);

TrainedModelAssignmentMetadata insufficientResult = new TrainedModelAssignmentRebalancer(
TrainedModelAssignmentMetadata.Builder.empty().build(),
insufficientNodeLoads,
Map.of(List.of(), List.of(insufficientNode)),
Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)),
1
).rebalance().build();

TrainedModelAssignment insufficientAssignment = insufficientResult.getDeploymentAssignment(deploymentId);
assertThat(insufficientAssignment, is(notNullValue()));
assertThat(insufficientAssignment.getAssignmentState(), equalTo(AssignmentState.STARTING));
assertThat(insufficientAssignment.getNodeRoutingTable(), is(anEmptyMap()));
assertThat(insufficientAssignment.getReason().isPresent(), is(true));
assertThat(insufficientAssignment.getReason().get(), containsString("insufficient available memory"));
}

private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) {
return lowPriorityParams(deploymentId, deploymentId, modelSize);
}