Skip to content

Commit 817c1c0

Browse files
committed
Test before model assignment.
1 parent 697622d commit 817c1c0

File tree

4 files changed

+93
-7
lines changed

4 files changed

+93
-7
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest,
138138
Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
139139
for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
140140
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
141-
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
141+
if(dest.canAssign(deployment, node, sourceAssignment.getValue())) {
142+
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
143+
}
142144
}
143145
}
144146
}
@@ -318,7 +320,12 @@ private Map<List<String>, List<AssignmentPlan.Node>> createNodesByZoneMap() {
318320
}
319321

320322
private static long getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(NodeLoad load) {
321-
return load.getFreeMemoryExcludingPerNodeOverhead() - load.getAssignedNativeInferenceMemory();
323+
// load.getFreeMemoryExcludingPerNodeOverhead() = maxMemory - assignedJobMemoryExcludingPerNodeOverhead - 30MB native executable code overhead
324+
// assignedJobMemoryExcludingPerNodeOverhead = assignedAnomalyDetectorMemory + assignedDataFrameAnalyticsMemory + assignedNativeInferenceMemory
325+
// load.getAssignedNativeInferenceMemory() = assignedNativeInferenceMemory
326+
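// TODO: (valeriy) assignedNativeInferenceMemory was double counted here: it is already included in assignedJobMemoryExcludingPerNodeOverhead, so subtracting load.getAssignedNativeInferenceMemory() again removed it twice. Confirm, then delete the commented-out term below.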
327+
return load.getFreeMemoryExcludingPerNodeOverhead()/* - load.getAssignedNativeInferenceMemory()*/;
328+
322329
}
323330

324331
private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(AssignmentPlan assignmentPlan) {
@@ -405,8 +412,17 @@ private Optional<String> explainAssignment(
405412
if (Strings.isNullOrEmpty(load.getError()) == false) {
406413
return Optional.of(load.getError());
407414
}
408-
409-
if (deployment.memoryBytes() > assignmentPlan.getRemainingNodeMemory(node.getId())) {
415+
// TODO (valeriy): this test should actually be true, but it is false because we use the "naked" deployment footprint
416+
// Get existing allocations for this node to avoid double counting
417+
int existingAllocationsOnNode = assignmentPlan.assignments(deployment)
418+
.flatMap(assignments -> assignments.entrySet().stream()
419+
.filter(entry -> entry.getKey().id().equals(node.getId()))
420+
.findFirst()
421+
.map(Map.Entry::getValue))
422+
.orElse(0);
423+
int notYetAssignedAllocations = deployment.allocations() - assignmentPlan.totalAllocations(deployment);
424+
// if (deployment.estimateMemoryUsageBytes(deployment.allocations() - existingAllocationsOnNode) > assignmentPlan.getRemainingNodeMemory(node.getId())) {
425+
if (deployment.estimateAdditionalMemoryUsageBytes(existingAllocationsOnNode, existingAllocationsOnNode + notYetAssignedAllocations) > assignmentPlan.getRemainingNodeMemory(node.getId())) {
410426
// If any ML processes are running on a node we require some space to load the shared libraries.
411427
// So if none are currently running then this per-node overhead must be added to the requirement.
412428
// From node load we know if we had any jobs or models assigned before the rebalance.

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public long estimateMemoryUsageBytes(int allocations) {
107107
);
108108
}
109109

110-
long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
110+
public long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
111111
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
112112
modelId,
113113
memoryBytes,
@@ -417,7 +417,7 @@ int getRemainingAllocations(Deployment m) {
417417
return remainingModelAllocations.get(m);
418418
}
419419

420-
boolean canAssign(Deployment deployment, Node node, int allocations) {
420+
public boolean canAssign(Deployment deployment, Node node, int allocations) {
421421
long requiredMemory = getDeploymentMemoryRequirement(deployment, node, allocations);
422422
return canAssign(deployment, node, allocations, requiredMemory);
423423
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ private AssignmentPlan swapOriginalDeploymentsInPlan(
211211
Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
212212
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
213213
Node originalNode = originalNodeById.get(assignment.getKey().id());
214-
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
214+
if (finalPlanBuilder.canAssign(originalDeployment, originalNode, assignment.getValue())) {
215+
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
216+
}
215217
}
216218
}
217219
return finalPlanBuilder.build();

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,4 +1262,72 @@ private static DiscoveryNode buildNode(String name, long nativeMemory, int alloc
12621262
)
12631263
.build();
12641264
}
1265+
1266+
public void testRebalance_GivenDeploymentWithMemoryRequirements_ConsidersNativeExecutableOverhead() {
1267+
// Create a node with just enough memory to fit the model plus native executable overhead
1268+
long modelMemory = ByteSizeValue.ofMb(200).getBytes();
1269+
long memoryOverhead = ByteSizeValue.ofMb(240).getBytes();
1270+
long JVMOverhead = ByteSizeValue.ofMb(50).getBytes();
1271+
long nodeMemory = memoryOverhead + modelMemory*2 + JVMOverhead;
1272+
1273+
DiscoveryNode node = buildNode("node-1", nodeMemory, 4);
1274+
1275+
String deploymentId = "model-with-overhead-test";
1276+
StartTrainedModelDeploymentAction.TaskParams taskParams = normalPriorityParams(
1277+
deploymentId,
1278+
deploymentId,
1279+
modelMemory,
1280+
1,
1281+
1
1282+
);
1283+
1284+
TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build();
1285+
Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
1286+
1287+
// This node has no jobs or models yet, so the overhead should be accounted for
1288+
nodeLoads.put(node, NodeLoad.builder("node-1")
1289+
.setMaxMemory(nodeMemory)
1290+
.build());
1291+
1292+
TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
1293+
currentMetadata,
1294+
nodeLoads,
1295+
Map.of(List.of(), List.of(node)),
1296+
Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)),
1297+
1
1298+
).rebalance().build();
1299+
1300+
// Verify the deployment was successful
1301+
TrainedModelAssignment assignment = result.getDeploymentAssignment(deploymentId);
1302+
assertThat(assignment, is(notNullValue()));
1303+
assertThat(assignment.getAssignmentState(), equalTo(AssignmentState.STARTING));
1304+
assertThat(assignment.getNodeRoutingTable(), is(aMapWithSize(1)));
1305+
assertThat(assignment.getNodeRoutingTable(), hasKey("node-1"));
1306+
assertThat(assignment.getReason().isPresent(), is(false));
1307+
1308+
// Now try with a node that has slightly less memory - this should fail
1309+
long insufficientNodeMemory = nodeMemory - ByteSizeValue.ofMb(21).getBytes();
1310+
DiscoveryNode insufficientNode = buildNode("node-2", insufficientNodeMemory, 4);
1311+
1312+
Map<DiscoveryNode, NodeLoad> insufficientNodeLoads = Map.of(
1313+
insufficientNode, NodeLoad.builder("node-2")
1314+
.setMaxMemory(insufficientNodeMemory)
1315+
.build()
1316+
);
1317+
1318+
TrainedModelAssignmentMetadata insufficientResult = new TrainedModelAssignmentRebalancer(
1319+
TrainedModelAssignmentMetadata.Builder.empty().build(),
1320+
insufficientNodeLoads,
1321+
Map.of(List.of(), List.of(insufficientNode)),
1322+
Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)),
1323+
1)
1324+
.rebalance().build();
1325+
1326+
TrainedModelAssignment insufficientAssignment = insufficientResult.getDeploymentAssignment(deploymentId);
1327+
assertThat(insufficientAssignment, is(notNullValue()));
1328+
assertThat(insufficientAssignment.getAssignmentState(), equalTo(AssignmentState.STARTING));
1329+
assertThat(insufficientAssignment.getNodeRoutingTable(), is(anEmptyMap()));
1330+
assertThat(insufficientAssignment.getReason().isPresent(), is(true));
1331+
assertThat(insufficientAssignment.getReason().get(), containsString("insufficient available memory"));
1332+
}
12651333
}

0 commit comments

Comments
 (0)