diff --git a/docs/changelog/133930.yaml b/docs/changelog/133930.yaml new file mode 100644 index 0000000000000..b8334341cf8bc --- /dev/null +++ b/docs/changelog/133930.yaml @@ -0,0 +1,6 @@ +pr: 133930 +summary: Improve memory estimation methods accuracy in `TrainedModelAssignmentRebalancer` + and related classes +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java index f523b4b086f35..196276a433d62 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -406,7 +406,28 @@ private Optional explainAssignment( return Optional.of(load.getError()); } - if (deployment.memoryBytes() > assignmentPlan.getRemainingNodeMemory(node.getId())) { + // Find how many allocations already exist on this node + // We need to search by node ID as assignmentPlan.assignments() returns a map + // of AssignmentPlan.Node and the argument node of the DiscoveryNode + int existingAllocationsOnNode = assignmentPlan.assignments(deployment) + .map( + assignments -> assignments.getOrDefault( + assignments.keySet().stream().filter(n -> n.id().equals(node.getId())).findFirst().orElse(null), + 0 + ) + ) + .orElse(0); + + // Calculate how many allocations remain to be assigned + int unassignedAllocations = deployment.allocations() - assignmentPlan.totalAllocations(deployment); + + // Check if there's enough memory for additional allocations + long additionalMemory = deployment.estimateAdditionalMemoryUsageBytes( + existingAllocationsOnNode, + existingAllocationsOnNode + unassignedAllocations + ); + long availableMemory = assignmentPlan.getRemainingNodeMemory(node.getId()); + if (additionalMemory > availableMemory) { // If any ML processes are running on a node we require some space to load the shared libraries. // So if none are currently running then this per-node overhead must be added to the requirement. // From node load we know if we had any jobs or models assigned before the rebalance. diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index a90a8cb9d5262..063014d616925 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -107,7 +107,7 @@ public long estimateMemoryUsageBytes(int allocations) { ); } - long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) { + public long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) { return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( modelId, memoryBytes, @@ -308,7 +308,7 @@ private Quality computeQuality() { Node n = nodeAllocations.getKey(); weighedAllocationsScore += (1 + 0.1 * (m.currentAllocationsByNodeId().containsKey(n.id()) ? 1 : 0)) * modelAssignments .get(n); - memoryScore -= (nodeAllocations.getValue() > 0 ? m.memoryBytes() : 0); + memoryScore -= (nodeAllocations.getValue() > 0 ? m.estimateMemoryUsageBytes(nodeAllocations.getValue()) : 0); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java index 90b3d3590a254..0b1ca67490bfd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java @@ -81,7 +81,7 @@ class LinearProgrammingPlanSolver { long maxNodeMemory = nodes.stream().map(Node::availableMemoryBytes).max(Long::compareTo).orElse(0L); this.deployments = deployments.stream() // Filter out models that are not already assigned and do not fit on any node - .filter(m -> m.currentAllocationsByNodeId().isEmpty() == false || m.memoryBytes() <= maxNodeMemory) + .filter(m -> m.currentAllocationsByNodeId().isEmpty() == false || m.minimumMemoryRequiredBytes() <= maxNodeMemory) // Also filter out models whose threads per allocation are more than the max node cores .filter(m -> m.threadsPerAllocation() <= maxNodeCores) .toList(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index b873493100798..65e91f8402ce5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -1198,6 +1198,64 @@ public void testCopyAssignments() { assertThat(deployment2Assignments.get().get(node2), equalTo(1)); } + public void testRebalance_GivenDeploymentWithMemoryRequirements_ConsidersNativeExecutableOverhead() { + // Create a node with just enough memory to fit the model plus native executable overhead + long modelMemory = ByteSizeValue.ofMb(200).getBytes(); + long memoryOverhead = ByteSizeValue.ofMb(240).getBytes(); + long nodeMemory = memoryOverhead + modelMemory * 2 + MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(); + + DiscoveryNode node = buildNode("node-1", nodeMemory, 4); + + String deploymentId = "model-with-overhead-test"; + StartTrainedModelDeploymentAction.TaskParams taskParams = normalPriorityParams(deploymentId, deploymentId, modelMemory, 1, 1); + + TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build(); + Map nodeLoads = new HashMap<>(); + + // This node has no jobs or models yet, so the overhead should be accounted for + nodeLoads.put(node, NodeLoad.builder("node-1").setMaxMemory(nodeMemory).build()); + + TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( + currentMetadata, + nodeLoads, + Map.of(List.of(), List.of(node)), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), + 1 + ).rebalance().build(); + + // Verify the deployment was successful + TrainedModelAssignment assignment = result.getDeploymentAssignment(deploymentId); + assertThat(assignment, is(notNullValue())); + assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1))); + assertThat(assignment.getNodeRoutingTable(), hasKey("node-1")); + assertThat(assignment.getReason().isPresent(), is(false)); + + // Now try with a node that has slightly less memory - this should fail + long insufficientNodeMemory = nodeMemory - ByteSizeValue.ofMb(21).getBytes(); + DiscoveryNode insufficientNode = buildNode("node-2", insufficientNodeMemory, 4); + + Map insufficientNodeLoads = Map.of( + insufficientNode, + NodeLoad.builder("node-2").setMaxMemory(insufficientNodeMemory).build() + ); + + TrainedModelAssignmentMetadata insufficientResult = new TrainedModelAssignmentRebalancer( + TrainedModelAssignmentMetadata.Builder.empty().build(), + insufficientNodeLoads, + Map.of(List.of(), List.of(insufficientNode)), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), + 1 + ).rebalance().build(); + + TrainedModelAssignment insufficientAssignment = insufficientResult.getDeploymentAssignment(deploymentId); + assertThat(insufficientAssignment, is(notNullValue())); + assertThat(insufficientAssignment.getAssignmentState(), equalTo(AssignmentState.STARTING)); + assertThat(insufficientAssignment.getNodeRoutingTable(), is(anEmptyMap())); + assertThat(insufficientAssignment.getReason().isPresent(), is(true)); + assertThat(insufficientAssignment.getReason().get(), containsString("insufficient available memory")); + } + private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) { return lowPriorityParams(deploymentId, deploymentId, modelSize); }