Skip to content

Commit ff2cde3

Browse files
committed
Mirror upstream elastic#133930 as single snapshot commit for AI review
BASE=647356e7d47d947e4deb37c402242dba009b5233 HEAD=05ab306852611b2a29c53d6646a8664fc7e93676 Branch=main
1 parent 647356e commit ff2cde3

File tree

5 files changed

+89
-4
lines changed

5 files changed

+89
-4
lines changed

docs/changelog/133930.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 133930
2+
summary: Improve memory estimation methods accuracy in `TrainedModelAssignmentRebalancer`
3+
and related classes
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,28 @@ private Optional<String> explainAssignment(
406406
return Optional.of(load.getError());
407407
}
408408

409-
if (deployment.memoryBytes() > assignmentPlan.getRemainingNodeMemory(node.getId())) {
409+
// Find how many allocations already exist on this node
410+
// We need to search by node ID as assignmentPlan.assignments() returns a map
411+
// of AssignmentPlan.Node and the argument node of the DiscoveryNode
412+
int existingAllocationsOnNode = assignmentPlan.assignments(deployment)
413+
.map(
414+
assignments -> assignments.getOrDefault(
415+
assignments.keySet().stream().filter(n -> n.id().equals(node.getId())).findFirst().orElse(null),
416+
0
417+
)
418+
)
419+
.orElse(0);
420+
421+
// Calculate how many allocations remain to be assigned
422+
int unassignedAllocations = deployment.allocations() - assignmentPlan.totalAllocations(deployment);
423+
424+
// Check if there's enough memory for additional allocations
425+
long additionalMemory = deployment.estimateAdditionalMemoryUsageBytes(
426+
existingAllocationsOnNode,
427+
existingAllocationsOnNode + unassignedAllocations
428+
);
429+
long availableMemory = assignmentPlan.getRemainingNodeMemory(node.getId());
430+
if (additionalMemory > availableMemory) {
410431
// If any ML processes are running on a node we require some space to load the shared libraries.
411432
// So if none are currently running then this per-node overhead must be added to the requirement.
412433
// From node load we know if we had any jobs or models assigned before the rebalance.

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public long estimateMemoryUsageBytes(int allocations) {
107107
);
108108
}
109109

110-
long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
110+
public long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
111111
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
112112
modelId,
113113
memoryBytes,
@@ -308,7 +308,7 @@ private Quality computeQuality() {
308308
Node n = nodeAllocations.getKey();
309309
weighedAllocationsScore += (1 + 0.1 * (m.currentAllocationsByNodeId().containsKey(n.id()) ? 1 : 0)) * modelAssignments
310310
.get(n);
311-
memoryScore -= (nodeAllocations.getValue() > 0 ? m.memoryBytes() : 0);
311+
memoryScore -= (nodeAllocations.getValue() > 0 ? m.estimateMemoryUsageBytes(nodeAllocations.getValue()) : 0);
312312
}
313313
}
314314
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/LinearProgrammingPlanSolver.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class LinearProgrammingPlanSolver {
8181
long maxNodeMemory = nodes.stream().map(Node::availableMemoryBytes).max(Long::compareTo).orElse(0L);
8282
this.deployments = deployments.stream()
8383
// Filter out models that are not already assigned and do not fit on any node
84-
.filter(m -> m.currentAllocationsByNodeId().isEmpty() == false || m.memoryBytes() <= maxNodeMemory)
84+
.filter(m -> m.currentAllocationsByNodeId().isEmpty() == false || m.minimumMemoryRequiredBytes() <= maxNodeMemory)
8585
// Also filter out models whose threads per allocation are more than the max node cores
8686
.filter(m -> m.threadsPerAllocation() <= maxNodeCores)
8787
.toList();

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,64 @@ public void testCopyAssignments() {
11981198
assertThat(deployment2Assignments.get().get(node2), equalTo(1));
11991199
}
12001200

1201+
public void testRebalance_GivenDeploymentWithMemoryRequirements_ConsidersNativeExecutableOverhead() {
1202+
// Create a node with just enough memory to fit the model plus native executable overhead
1203+
long modelMemory = ByteSizeValue.ofMb(200).getBytes();
1204+
long memoryOverhead = ByteSizeValue.ofMb(240).getBytes();
1205+
long nodeMemory = memoryOverhead + modelMemory * 2 + MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes();
1206+
1207+
DiscoveryNode node = buildNode("node-1", nodeMemory, 4);
1208+
1209+
String deploymentId = "model-with-overhead-test";
1210+
StartTrainedModelDeploymentAction.TaskParams taskParams = normalPriorityParams(deploymentId, deploymentId, modelMemory, 1, 1);
1211+
1212+
TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build();
1213+
Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
1214+
1215+
// This node has no jobs or models yet, so the overhead should be accounted for
1216+
nodeLoads.put(node, NodeLoad.builder("node-1").setMaxMemory(nodeMemory).build());
1217+
1218+
TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
1219+
currentMetadata,
1220+
nodeLoads,
1221+
Map.of(List.of(), List.of(node)),
1222+
Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)),
1223+
1
1224+
).rebalance().build();
1225+
1226+
// Verify the deployment was successful
1227+
TrainedModelAssignment assignment = result.getDeploymentAssignment(deploymentId);
1228+
assertThat(assignment, is(notNullValue()));
1229+
assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING));
1230+
assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1)));
1231+
assertThat(assignment.getNodeRoutingTable(), hasKey("node-1"));
1232+
assertThat(assignment.getReason().isPresent(), is(false));
1233+
1234+
// Now try with a node that has slightly less memory - this should fail
1235+
long insufficientNodeMemory = nodeMemory - ByteSizeValue.ofMb(21).getBytes();
1236+
DiscoveryNode insufficientNode = buildNode("node-2", insufficientNodeMemory, 4);
1237+
1238+
Map<DiscoveryNode, NodeLoad> insufficientNodeLoads = Map.of(
1239+
insufficientNode,
1240+
NodeLoad.builder("node-2").setMaxMemory(insufficientNodeMemory).build()
1241+
);
1242+
1243+
TrainedModelAssignmentMetadata insufficientResult = new TrainedModelAssignmentRebalancer(
1244+
TrainedModelAssignmentMetadata.Builder.empty().build(),
1245+
insufficientNodeLoads,
1246+
Map.of(List.of(), List.of(insufficientNode)),
1247+
Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)),
1248+
1
1249+
).rebalance().build();
1250+
1251+
TrainedModelAssignment insufficientAssignment = insufficientResult.getDeploymentAssignment(deploymentId);
1252+
assertThat(insufficientAssignment, is(notNullValue()));
1253+
assertThat(insufficientAssignment.getAssignmentState(), equalTo(AssignmentState.STARTING));
1254+
assertThat(insufficientAssignment.getNodeRoutingTable(), is(anEmptyMap()));
1255+
assertThat(insufficientAssignment.getReason().isPresent(), is(true));
1256+
assertThat(insufficientAssignment.getReason().get(), containsString("insufficient available memory"));
1257+
}
1258+
12011259
private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) {
12021260
return lowPriorityParams(deploymentId, deploymentId, modelSize);
12031261
}

0 commit comments

Comments
 (0)