Skip to content

Commit 43c81cb

Browse files
committed
Merge branch 'main' into esql_mark_spec_tests
2 parents 39b6594 + 80c47f3 commit 43c81cb

File tree

6 files changed

+182
-83
lines changed

6 files changed

+182
-83
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -130,23 +130,15 @@ private static AssignmentPlan mergePlans(
130130
return finalPlanBuilder.build();
131131
}
132132

133-
private static void copyAssignments(
134-
AssignmentPlan source,
135-
AssignmentPlan.Builder dest,
136-
Map<String, AssignmentPlan.Node> originalNodeById
137-
) {
138-
for (AssignmentPlan.Deployment m : source.deployments()) {
139-
Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
140-
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
141-
AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
142-
dest.assignModelToNode(m, originalNode, assignment.getValue());
143-
if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
144-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
145-
// As the node has all its available memory we need to manually account memory of models with
146-
// current allocations.
147-
long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
148-
dest.accountMemory(m, originalNode, requiredMemory);
149-
}
133+
/**
134+
* Transfers assignments from the source AssignmentPlan to the destination AssignmentPlan.Builder.
135+
*/
136+
static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map<String, AssignmentPlan.Node> originalNodeById) {
137+
for (AssignmentPlan.Deployment deployment : source.deployments()) {
138+
Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
139+
for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
140+
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
141+
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
150142
}
151143
}
152144
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -223,23 +223,43 @@ public int compareTo(AssignmentPlan o) {
223223
return Comparator.comparing(AssignmentPlan::computeQuality).compare(this, o);
224224
}
225225

226+
/**
227+
* Checks whether all deployments in the current {@link AssignmentPlan} have at least as many
228+
* allocations as currently assigned.
229+
*/
226230
public boolean satisfiesCurrentAssignments() {
227231
return deployments().stream().allMatch(this::isSatisfyingCurrentAssignmentsForModel);
228232
}
229233

234+
/**
235+
* Checks whether the current assignments for a given {@link Deployment} meet its allocation requirements.
236+
*
237+
* It ensures that the total number of allocations assigned to the deployment across all nodes is
238+
* at least equal to the deployment's current assigned allocations.
239+
*/
230240
private boolean isSatisfyingCurrentAssignmentsForModel(Deployment m) {
231241
if (m.currentAllocationsByNodeId().isEmpty()) {
232242
return true;
233243
}
234244
Map<Node, Integer> nodeAssignments = assignments.get(m);
235-
int currentAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
236-
return currentAllocations >= m.getCurrentAssignedAllocations();
245+
int inPlanAssignedAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
246+
return inPlanAssignedAllocations >= m.getCurrentAssignedAllocations();
237247
}
238248

239-
public boolean satisfiesAllocations(Deployment m) {
240-
return remainingModelAllocations.getOrDefault(m, 0) == 0;
249+
/**
250+
* Checks if the current assignments satisfy the deployment's allocation requirements.
251+
* @param deployment the deployment to check
252+
* @return true if the current assignments satisfy the deployment's allocation requirements, false otherwise
253+
*/
254+
public boolean satisfiesAllocations(Deployment deployment) {
255+
return remainingModelAllocations.getOrDefault(deployment, 0) == 0;
241256
}
242257

258+
/**
259+
* Checks if the current assignments satisfy all deployments' allocation requirements. This means that
260+
* each deployment has no remaining allocations left to assign.
261+
* @return true if the current assignments satisfy the deployments' allocation requirements, false otherwise
262+
*/
243263
public boolean satisfiesAllModels() {
244264
return deployments().stream().allMatch(this::satisfiesAllocations);
245265
}
@@ -424,8 +444,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
424444
if (allocations <= 0) {
425445
return this;
426446
}
427-
if (/*isAlreadyAssigned(deployment, node) == false
428-
&&*/ requiredMemory > remainingNodeMemory.get(node)) {
447+
if (requiredMemory > remainingNodeMemory.get(node)) {
429448
throw new IllegalArgumentException(
430449
"not enough memory on node ["
431450
+ node.id()
@@ -450,7 +469,7 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
450469
);
451470
}
452471

453-
assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations);
472+
assignments.get(deployment).compute(node, (n, assignedAllocations) -> assignedAllocations + allocations);
454473
accountMemory(deployment, node, requiredMemory);
455474

456475
if (deployment.priority == Priority.NORMAL) {
@@ -461,23 +480,10 @@ public Builder assignModelToNode(Deployment deployment, Node node, int allocatio
461480
}
462481

463482
private int getAssignedAllocations(Deployment deployment, Node node) {
464-
int currentAllocations = getCurrentAllocations(deployment, node);
465-
int assignmentAllocations = assignments.get(deployment).get(node);
466-
return currentAllocations + assignmentAllocations;
467-
}
468-
469-
private static int getCurrentAllocations(Deployment m, Node n) {
470-
return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0;
471-
}
472-
473-
public void accountMemory(Deployment m, Node n) {
474-
// TODO (#101612) remove or refactor unused method
475-
long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n));
476-
accountMemory(m, n, requiredMemory);
483+
return assignments.get(deployment).get(node);
477484
}
478485

479486
public void accountMemory(Deployment m, Node n, long requiredMemory) {
480-
// TODO (#101612) computation of required memory should be done internally
481487
remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory);
482488
if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) {
483489
throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.deploymentId() + "]");

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,26 @@ public AssignmentPlan computePlan() {
6060
return computePlan(true);
6161
}
6262

63-
public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
63+
/**
64+
* Computes an {@link AssignmentPlan} for the given nodes and deployments.
65+
* If {@code tryAssigningAllPreviouslyAllocatedModels} is true, then the plan will
66+
* attempt to assign at least one allocation to previously assigned models.
67+
* Otherwise, it only ensures that deployments assigned to existing nodes preserve at least one allocation.
68+
*
69+
* @param tryAssigningAllPreviouslyAllocatedModels whether to make a best effort to assign previously assigned models somewhere
70+
* with at least one allocation
71+
* @return the computed assignment plan
72+
*/
73+
public AssignmentPlan computePlan(boolean tryAssigningAllPreviouslyAllocatedModels) {
6474
logger.debug(() -> format("Computing plan for nodes = %s; deployments = %s", nodes, deployments));
6575

6676
AssignmentPlan bestPlan;
6777
AssignmentPlan planSatisfyingCurrentAssignments = solveSatisfyingCurrentAssignments();
6878
logger.debug(() -> "Plan satisfying current assignments =\n" + planSatisfyingCurrentAssignments.prettyPrint());
69-
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() == false && tryAssigningPreviouslyAssignedModels) {
79+
if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() || tryAssigningAllPreviouslyAllocatedModels == false) {
80+
bestPlan = planSatisfyingCurrentAssignments;
81+
} else {
82+
// try to give at least one allocation to any previously allocated deployment that would otherwise drop to zero allocations
7083
AssignmentPlan planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated =
7184
solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated();
7285
logger.debug(
@@ -82,28 +95,37 @@ public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels)
8295
? planSatisfyingCurrentAssignments
8396
: planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated;
8497
}
85-
} else {
86-
bestPlan = planSatisfyingCurrentAssignments;
8798
}
8899

89100
logger.debug(() -> "Best plan =\n" + bestPlan.prettyPrint());
90101
logger.debug(() -> prettyPrintOverallStats(bestPlan));
91102
return bestPlan;
92103
}
93104

105+
/**
106+
* Computes the best assignment plan from two strategies:
107+
* 1. Preserving one allocation on current assignments, which is the most flexible
108+
* 2. Preserving all allocations on current assignments, which is more conservative
109+
* @return the best assignment plan
110+
*/
94111
private AssignmentPlan solveSatisfyingCurrentAssignments() {
95112
AssignmentPlan bestPlan;
96113
// First solve preserving one allocation per assignment because that is most flexible
97114
AssignmentPlan planKeepingOneAllocationOnCurrentAssignments = solveKeepingOneAllocationOnCurrentAssignments();
98-
if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
115+
116+
if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels()) {
117+
// If the plan satisfies all models, then we can use it as is
118+
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
119+
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
120+
// If in the new assignment plan, some deployments have fewer allocations than in the current assignments,
121+
// try explicitly preserving all allocations on current assignments.
99122
bestPlan = solvePreservingAllAllocationsOnCurrentAssignments();
100-
} else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels() == false) {
123+
} else {
124+
// Choose the better of the two plans via AssignmentPlan#compareTo, which orders plans by AssignmentPlan#computeQuality
101125
AssignmentPlan planKeepingAllAllocationsOnCurrentAssignments = solvePreservingAllAllocationsOnCurrentAssignments();
102126
bestPlan = planKeepingAllAllocationsOnCurrentAssignments.compareTo(planKeepingOneAllocationOnCurrentAssignments) >= 0
103127
? planKeepingAllAllocationsOnCurrentAssignments
104128
: planKeepingOneAllocationOnCurrentAssignments;
105-
} else {
106-
bestPlan = planKeepingOneAllocationOnCurrentAssignments;
107129
}
108130
return bestPlan;
109131
}
@@ -120,7 +142,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat
120142
1,
121143
m.threadsPerAllocation(),
122144
// don't rely on the current allocation
123-
new HashMap<>(),
145+
Map.of(),
124146
m.maxAssignedAllocations(),
125147
m.getAdaptiveAllocationsSettings(),
126148
m.perDeploymentMemoryBytes(),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -182,36 +182,39 @@ private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
182182
List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
183183
AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
184184
plan = preserveAllAllocations.mergePreservedAllocations(plan);
185-
return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
185+
return swapOriginalDeploymentsInPlan(plan, allNodes, modelsAccountingPlans);
186186
}
187187

188-
private AssignmentPlan swapOriginalModelsInPlan(
188+
/**
189+
* Reconstructs an {@link AssignmentPlan}
190+
* by replacing the deployments and nodes in the given plan with their original counterparts.
191+
* This ensures that the final plan uses the original objects while preserving the assignments
192+
* and memory accounting from the input plan.
193+
*
194+
* @param plan AssignmentPlan to reconstruct with original models and nodes
195+
* @param allNodes List of all nodes in the system, used to find original nodes
196+
* @param planDeployments List of deployments in the plan, not the original deployments
197+
* @return final plan with original models and nodes swapped in
198+
*/
199+
private AssignmentPlan swapOriginalDeploymentsInPlan(
189200
AssignmentPlan plan,
190201
List<Node> allNodes,
191202
List<AssignmentPlan.Deployment> planDeployments
192203
) {
193-
final Map<String, AssignmentPlan.Deployment> originalModelById = deployments.stream()
204+
final Map<String, AssignmentPlan.Deployment> originalDeploymentsById = deployments.stream()
194205
.collect(Collectors.toMap(AssignmentPlan.Deployment::deploymentId, Function.identity()));
195206
final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
196-
AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, deployments);
197-
for (AssignmentPlan.Deployment m : planDeployments) {
198-
AssignmentPlan.Deployment originalDeployment = originalModelById.get(m.deploymentId());
199-
Map<Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
207+
AssignmentPlan.Builder finalPlanBuilder = AssignmentPlan.builder(allNodes, deployments);
208+
209+
for (AssignmentPlan.Deployment planDeployment : planDeployments) {
210+
AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.deploymentId());
211+
Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
200212
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
201213
Node originalNode = originalNodeById.get(assignment.getKey().id());
202-
planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
203-
if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
204-
// TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
205-
// As the node has all its available memory we need to manually account memory of models with
206-
// current allocations.
207-
long requiredMemory = originalDeployment.estimateMemoryUsageBytes(
208-
originalDeployment.currentAllocationsByNodeId().get(originalNode.id())
209-
);
210-
planBuilder.accountMemory(m, originalNode, requiredMemory);
211-
}
214+
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
212215
}
213216
}
214-
return planBuilder.build();
217+
return finalPlanBuilder.build();
215218
}
216219

217220
private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByDeploymentId(List<AssignmentPlan> plans) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@
2121
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
2222
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
2323
import org.elasticsearch.xpack.ml.MachineLearning;
24+
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
2425
import org.elasticsearch.xpack.ml.job.NodeLoad;
2526

2627
import java.util.ArrayList;
2728
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Optional;
32+
import java.util.function.Function;
33+
import java.util.stream.Collectors;
3134

3235
import static org.hamcrest.Matchers.aMapWithSize;
3336
import static org.hamcrest.Matchers.anEmptyMap;
@@ -1127,6 +1130,74 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() {
11271130
assertThat(assignment.getReason().isPresent(), is(false));
11281131
}
11291132

1133+
public void testCopyAssignments() {
1134+
// Create test nodes
1135+
AssignmentPlan.Node node1 = new AssignmentPlan.Node("node-1", ByteSizeValue.ofGb(1).getBytes(), 4);
1136+
AssignmentPlan.Node node2 = new AssignmentPlan.Node("node-2", ByteSizeValue.ofGb(1).getBytes(), 8);
1137+
List<AssignmentPlan.Node> nodes = List.of(node1, node2);
1138+
1139+
// Create test deployments
1140+
AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
1141+
"deployment-1",
1142+
"model-1",
1143+
ByteSizeValue.ofMb(100).getBytes(),
1144+
2,
1145+
1,
1146+
Map.of(),
1147+
0,
1148+
null,
1149+
Priority.NORMAL,
1150+
0,
1151+
0
1152+
);
1153+
AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
1154+
"deployment-2",
1155+
"model-2",
1156+
ByteSizeValue.ofMb(100).getBytes(),
1157+
1,
1158+
2,
1159+
Map.of(),
1160+
0,
1161+
null,
1162+
Priority.LOW,
1163+
0,
1164+
0
1165+
);
1166+
List<AssignmentPlan.Deployment> deployments = List.of(deployment1, deployment2);
1167+
1168+
// Create source plan and assign models to nodes
1169+
AssignmentPlan.Builder sourceBuilder = AssignmentPlan.builder(nodes, deployments);
1170+
sourceBuilder.assignModelToNode(deployment1, node1, 1);
1171+
sourceBuilder.assignModelToNode(deployment1, node2, 1);
1172+
sourceBuilder.assignModelToNode(deployment2, node2, 1);
1173+
AssignmentPlan source = sourceBuilder.build();
1174+
1175+
// Create destination plan
1176+
AssignmentPlan.Builder dest = AssignmentPlan.builder(nodes, deployments);
1177+
1178+
// Create map of node IDs to original nodes
1179+
Map<String, AssignmentPlan.Node> originalNodeById = nodes.stream()
1180+
.collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));
1181+
1182+
// Call copyAssignments
1183+
TrainedModelAssignmentRebalancer.copyAssignments(source, dest, originalNodeById);
1184+
1185+
// Build the destination plan
1186+
AssignmentPlan result = dest.build();
1187+
1188+
// Verify assignments
1189+
Optional<Map<AssignmentPlan.Node, Integer>> deployment1Assignments = result.assignments(deployment1);
1190+
assertThat(deployment1Assignments.isPresent(), is(true));
1191+
assertThat(deployment1Assignments.get().size(), equalTo(2));
1192+
assertThat(deployment1Assignments.get().get(node1), equalTo(1));
1193+
assertThat(deployment1Assignments.get().get(node2), equalTo(1));
1194+
1195+
Optional<Map<AssignmentPlan.Node, Integer>> deployment2Assignments = result.assignments(deployment2);
1196+
assertThat(deployment2Assignments.isPresent(), is(true));
1197+
assertThat(deployment2Assignments.get().size(), equalTo(1));
1198+
assertThat(deployment2Assignments.get().get(node2), equalTo(1));
1199+
}
1200+
11301201
private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) {
11311202
return lowPriorityParams(deploymentId, deploymentId, modelSize);
11321203
}

0 commit comments

Comments
 (0)