diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java index f523b4b086f35..8fa6cdfe94438 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -135,9 +135,9 @@ private static AssignmentPlan mergePlans( */ static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map originalNodeById) { for (AssignmentPlan.Deployment deployment : source.deployments()) { - Map sourceNodeAssignments = source.assignments(deployment).orElse(Map.of()); - for (Map.Entry sourceAssignment : sourceNodeAssignments.entrySet()) { - AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id()); + Map sourceNodeAssignments = source.assignments(deployment).orElse(Map.of()); + for (Map.Entry sourceAssignment : sourceNodeAssignments.entrySet()) { + AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey()); dest.assignModelToNode(deployment, node, sourceAssignment.getValue()); } } @@ -337,10 +337,10 @@ private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(Assignme assignmentBuilder.setMaxAssignedAllocations(existingAssignment.getMaxAssignedAllocations()); } - Map assignments = assignmentPlan.assignments(deployment).orElseGet(Map::of); - for (Map.Entry assignment : assignments.entrySet()) { - if (existingAssignment != null && existingAssignment.isRoutedToNode(assignment.getKey().id())) { - RoutingInfo existingRoutingInfo = existingAssignment.getNodeRoutingTable().get(assignment.getKey().id()); + Map assignments = assignmentPlan.assignments(deployment).orElseGet(Map::of); + for (Map.Entry assignment : assignments.entrySet()) { + if (existingAssignment != null && existingAssignment.isRoutedToNode(assignment.getKey())) { + RoutingInfo existingRoutingInfo = existingAssignment.getNodeRoutingTable().get(assignment.getKey()); RoutingState state = existingRoutingInfo.getState(); String reason = existingRoutingInfo.getReason(); if (state == RoutingState.FAILED) { @@ -348,12 +348,12 @@ private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(Assignme reason = ""; } assignmentBuilder.addRoutingEntry( - assignment.getKey().id(), + assignment.getKey(), new RoutingInfo(existingRoutingInfo.getCurrentAllocations(), assignment.getValue(), state, reason) ); } else { assignmentBuilder.addRoutingEntry( - assignment.getKey().id(), + assignment.getKey(), new RoutingInfo(assignment.getValue(), assignment.getValue(), RoutingState.STARTING, "") ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java index 8a0bbe2ecdd5e..ac3d29aad130d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java @@ -74,10 +74,10 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) { // with its preserved allocations. final Map, Integer> plannedAssignmentsByDeploymentNodeIdPair = new HashMap<>(); for (Deployment d : assignmentPlan.deployments()) { - Map assignmentsOfDeployment = assignmentPlan.assignments(d).orElse(Map.of()); - for (Map.Entry nodeAssignment : assignmentsOfDeployment.entrySet()) { + Map assignmentsOfDeployment = assignmentPlan.assignments(d).orElse(Map.of()); + for (Map.Entry nodeAssignment : assignmentsOfDeployment.entrySet()) { plannedAssignmentsByDeploymentNodeIdPair.put( - Tuple.tuple(d.deploymentId(), nodeAssignment.getKey().id()), + Tuple.tuple(d.deploymentId(), nodeAssignment.getKey()), nodeAssignment.getValue() ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index a90a8cb9d5262..7bb616bb1f92e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -213,9 +213,15 @@ public Set deployments() { * @param deployment the model for which assignments are returned * @return the model assignments per node. The Optional will be empty if the model has no assignments. */ - public Optional> assignments(Deployment deployment) { + public Optional> assignments(Deployment deployment) { Map modelAssignments = assignments.get(deployment); - return (modelAssignments == null || modelAssignments.isEmpty()) ? Optional.empty() : Optional.of(modelAssignments); + if (modelAssignments == null || modelAssignments.isEmpty()) { + return Optional.empty(); + } + Map byNodeId = modelAssignments.entrySet() + .stream() + .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); + return Optional.of(byNodeId); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index bb7998035ff46..ee6250dc88834 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -157,11 +157,11 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat Map modelIdToNodeIdWithSingleAllocation = new HashMap<>(); for (AssignmentPlan.Deployment m : planWithSingleAllocationForPreviouslyAssignedModels.deployments()) { - Optional> assignments = planWithSingleAllocationForPreviouslyAssignedModels.assignments(m); - Set nodes = assignments.orElse(Map.of()).keySet(); + Optional> assignments = planWithSingleAllocationForPreviouslyAssignedModels.assignments(m); + Set nodes = assignments.orElse(Map.of()).keySet(); if (nodes.isEmpty() == false) { assert nodes.size() == 1; - modelIdToNodeIdWithSingleAllocation.put(m.deploymentId(), nodes.iterator().next().id()); + modelIdToNodeIdWithSingleAllocation.put(m.deploymentId(), nodes.iterator().next()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java index 64cd40fdc537d..304fbde68a867 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java @@ -208,9 +208,9 @@ private AssignmentPlan swapOriginalDeploymentsInPlan( for (AssignmentPlan.Deployment planDeployment : planDeployments) { AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.deploymentId()); - Map nodeAssignments = plan.assignments(planDeployment).orElse(Map.of()); - for (Map.Entry assignment : nodeAssignments.entrySet()) { - Node originalNode = originalNodeById.get(assignment.getKey().id()); + Map nodeAssignments = plan.assignments(planDeployment).orElse(Map.of()); + for (Map.Entry assignment : nodeAssignments.entrySet()) { + Node originalNode = originalNodeById.get(assignment.getKey()); finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue()); } } @@ -223,11 +223,11 @@ private Map> mergeAllocationsByNodeIdByDeploymentId for (AssignmentPlan plan : plans) { for (AssignmentPlan.Deployment m : plan.deployments()) { Map nodeIdToAllocations = allocationsByNodeIdByDeploymentId.get(m.deploymentId()); - Optional> assignments = plan.assignments(m); + Optional> assignments = plan.assignments(m); if (assignments.isPresent()) { - for (Map.Entry nodeAssignments : assignments.get().entrySet()) { + for (Map.Entry nodeAssignments : assignments.get().entrySet()) { nodeIdToAllocations.compute( - nodeAssignments.getKey().id(), + nodeAssignments.getKey(), (nodeId, existingAllocations) -> existingAllocations == null ? nodeAssignments.getValue() : existingAllocations + nodeAssignments.getValue() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java index c7f166a19bb69..6d7a58315fd0a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java @@ -73,7 +73,7 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() { assertThat(plan.deployments(), contains(m)); assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1))); } { // new memory format AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( @@ -107,7 +107,7 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() { assertThat(plan.deployments(), contains(m)); assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1))); } } @@ -140,7 +140,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { assertThat(plan.deployments(), contains(m)); assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1))); } { // new memory format AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( @@ -169,7 +169,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { assertThat(plan.deployments(), contains(m)); assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1))); } } @@ -195,7 +195,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() assertThat(plan.deployments(), contains(m)); assertThat(plan.satisfiesCurrentAssignments(), is(false)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1))); } { // new memory format @@ -229,7 +229,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() assertThat(plan.deployments(), contains(m)); assertThat(plan.satisfiesCurrentAssignments(), is(false)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1))); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1))); } } @@ -365,7 +365,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { assertThat(plan.deployments(), contains(m)); assertThat(plan.satisfiesCurrentAssignments(), is(true)); - assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 3))); + assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 3))); } public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java index 2a5b9839f80c3..4d1ba5a1ba03f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java @@ -209,11 +209,11 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); - Map assignments = plan.assignments(deployment).get(); - if (assignments.get(node1) != null) { - assertThat(assignments.get(node1), equalTo(4)); + Map assignments = plan.assignments(deployment).get(); + if (assignments.get(node1.id()) != null) { + assertThat(assignments.get(node1.id()), equalTo(4)); } else { - assertThat(assignments.get(node2), equalTo(4)); + assertThat(assignments.get(node2.id()), equalTo(4)); } } @@ -235,11 +235,11 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); - Map assignments = plan.assignments(deployment).get(); - if (assignments.get(node1) != null) { - assertThat(assignments.get(node1), equalTo(4)); + Map assignments = plan.assignments(deployment).get(); + if (assignments.get(node1.id()) != null) { + assertThat(assignments.get(node1.id()), equalTo(4)); } else { - assertThat(assignments.get(node2), equalTo(4)); + assertThat(assignments.get(node2.id()), equalTo(4)); } } @@ -261,8 +261,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA Node node = new Node("n_1", scaleNodeSize(100), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node), equalTo(4)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node.id()), equalTo(4)); } // Two nodes { @@ -270,9 +270,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA Node node2 = new Node("n_2", scaleNodeSize(100), 2); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(4)); - assertThat(assignments.get(node2), equalTo(2)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(4)); + assertThat(assignments.get(node2.id()), equalTo(2)); } // Three nodes { @@ -281,10 +281,10 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA Node node3 = new Node("n_3", scaleNodeSize(100), 3); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(4)); - assertThat(assignments.get(node2), equalTo(2)); - assertThat(assignments.get(node3), equalTo(3)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(4)); + assertThat(assignments.get(node2.id()), equalTo(2)); + assertThat(assignments.get(node3.id()), equalTo(3)); } } @@ -306,8 +306,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node), equalTo(4)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node.id()), equalTo(4)); } // Two nodes { @@ -315,9 +315,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA Node node2 = new Node("n_2", ByteSizeValue.ofMb(600).getBytes(), 2); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(4)); - assertThat(assignments.get(node2), equalTo(2)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(4)); + assertThat(assignments.get(node2.id()), equalTo(2)); } // Three nodes { @@ -326,10 +326,10 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA Node node3 = new Node("n_3", ByteSizeValue.ofMb(700).getBytes(), 3); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(4)); - assertThat(assignments.get(node2), equalTo(2)); - assertThat(assignments.get(node3), equalTo(3)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(4)); + assertThat(assignments.get(node2.id()), equalTo(2)); + assertThat(assignments.get(node3.id()), equalTo(3)); } } @@ -350,41 +350,41 @@ public void testMultipleDeploymentsAndNodesWithSingleSolution() { { assertThat(plan.assignments(deployment1).isPresent(), is(true)); - Map assignments = plan.assignments(deployment1).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(1)); - assertThat(assignments.get(node3), is(nullValue())); - assertThat(assignments.get(node4), is(nullValue())); + Map assignments = plan.assignments(deployment1).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(1)); + assertThat(assignments.get(node3.id()), is(nullValue())); + assertThat(assignments.get(node4.id()), is(nullValue())); } { assertThat(plan.assignments(deployment2).isPresent(), is(true)); - Map assignments = plan.assignments(deployment2).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(1)); - assertThat(assignments.get(node3), is(nullValue())); - assertThat(assignments.get(node4), is(nullValue())); + Map assignments = plan.assignments(deployment2).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(1)); + assertThat(assignments.get(node3.id()), is(nullValue())); + assertThat(assignments.get(node4.id()), is(nullValue())); } { assertThat(plan.assignments(deployment3).isPresent(), is(true)); - Map assignments = plan.assignments(deployment3).get(); - assertThat(assignments.get(node1), is(nullValue())); - assertThat(assignments.get(node2), is(nullValue())); + Map assignments = plan.assignments(deployment3).get(); + assertThat(assignments.get(node1.id()), is(nullValue())); + assertThat(assignments.get(node2.id()), is(nullValue())); // Will either be on node 3 or 4 - Node assignedNode = assignments.get(node3) != null ? node3 : node4; - Node otherNode = assignedNode.equals(node3) ? node4 : node3; - assertThat(assignments.get(assignedNode), equalTo(1)); - assertThat(assignments.get(otherNode), is(nullValue())); + String assignedNodeId = assignments.get(node3.id()) != null ? node3.id() : node4.id(); + String otherNodeId = assignedNodeId.equals(node3.id()) ? node4.id() : node3.id(); + assertThat(assignments.get(assignedNodeId), equalTo(1)); + assertThat(assignments.get(otherNodeId), is(nullValue())); } { assertThat(plan.assignments(deployment4).isPresent(), is(true)); - Map assignments = plan.assignments(deployment4).get(); - assertThat(assignments.get(node1), is(nullValue())); - assertThat(assignments.get(node2), is(nullValue())); + Map assignments = plan.assignments(deployment4).get(); + assertThat(assignments.get(node1.id()), is(nullValue())); + assertThat(assignments.get(node2.id()), is(nullValue())); // Will either be on node 3 or 4 - Node assignedNode = assignments.get(node3) != null ? node3 : node4; - Node otherNode = assignedNode.equals(node3) ? node4 : node3; - assertThat(assignments.get(assignedNode), equalTo(2)); - assertThat(assignments.get(otherNode), is(nullValue())); + String assignedNodeId = assignments.get(node3.id()) != null ? node3.id() : node4.id(); + String otherNodeId = assignedNodeId.equals(node3.id()) ? node4.id() : node3.id(); + assertThat(assignments.get(assignedNodeId), equalTo(2)); + assertThat(assignments.get(otherNodeId), is(nullValue())); } } @@ -449,41 +449,41 @@ public void testMultipleDeploymentsAndNodesWithSingleSolution_NewMemoryFields() { assertThat(plan.assignments(deployment1).isPresent(), is(true)); - Map assignments = plan.assignments(deployment1).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(1)); - assertThat(assignments.get(node3), is(nullValue())); - assertThat(assignments.get(node4), is(nullValue())); + Map assignments = plan.assignments(deployment1).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(1)); + assertThat(assignments.get(node3.id()), is(nullValue())); + assertThat(assignments.get(node4.id()), is(nullValue())); } { assertThat(plan.assignments(deployment2).isPresent(), is(true)); - Map assignments = plan.assignments(deployment2).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(1)); - assertThat(assignments.get(node3), is(nullValue())); - assertThat(assignments.get(node4), is(nullValue())); + Map assignments = plan.assignments(deployment2).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(1)); + assertThat(assignments.get(node3.id()), is(nullValue())); + assertThat(assignments.get(node4.id()), is(nullValue())); } { assertThat(plan.assignments(deployment3).isPresent(), is(true)); - Map assignments = plan.assignments(deployment3).get(); - assertThat(assignments.get(node1), is(nullValue())); - assertThat(assignments.get(node2), is(nullValue())); + Map assignments = plan.assignments(deployment3).get(); + assertThat(assignments.get(node1.id()), is(nullValue())); + assertThat(assignments.get(node2.id()), is(nullValue())); // Will either be on node 3 or 4 - Node assignedNode = assignments.get(node3) != null ? node3 : node4; - Node otherNode = assignedNode.equals(node3) ? node4 : node3; - assertThat(assignments.get(assignedNode), equalTo(1)); - assertThat(assignments.get(otherNode), is(nullValue())); + String assignedNodeId = assignments.get(node3.id()) != null ? node3.id() : node4.id(); + String otherNodeId = assignedNodeId.equals(node3.id()) ? node4.id() : node3.id(); + assertThat(assignments.get(assignedNodeId), equalTo(1)); + assertThat(assignments.get(otherNodeId), is(nullValue())); } { assertThat(plan.assignments(deployment4).isPresent(), is(true)); - Map assignments = plan.assignments(deployment4).get(); - assertThat(assignments.get(node1), is(nullValue())); - assertThat(assignments.get(node2), is(nullValue())); + Map assignments = plan.assignments(deployment4).get(); + assertThat(assignments.get(node1.id()), is(nullValue())); + assertThat(assignments.get(node2.id()), is(nullValue())); // Will either be on node 3 or 4 - Node assignedNode = assignments.get(node3) != null ? node3 : node4; - Node otherNode = assignedNode.equals(node3) ? node4 : node3; - assertThat(assignments.get(assignedNode), equalTo(2)); - assertThat(assignments.get(otherNode), is(nullValue())); + String assignedNodeId = assignments.get(node3.id()) != null ? node3.id() : node4.id(); + String otherNodeId = assignedNodeId.equals(node3.id()) ? node4.id() : node3.id(); + assertThat(assignments.get(assignedNodeId), equalTo(2)); + assertThat(assignments.get(otherNodeId), is(nullValue())); } } @@ -505,8 +505,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA Node node = new Node("n_1", scaleNodeSize(100), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node), equalTo(1)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node.id()), equalTo(1)); } // Two nodes { @@ -514,9 +514,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA Node node2 = new Node("n_2", scaleNodeSize(100), 8); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(2)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(2)); } // Three nodes { @@ -525,10 +525,10 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA Node node3 = new Node("n_3", scaleNodeSize(100), 15); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(2)); - assertThat(assignments.get(node3), equalTo(5)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(2)); + assertThat(assignments.get(node3.id()), equalTo(5)); } } @@ -550,8 +550,8 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA Node node = new Node("n_1", ByteSizeValue.ofMb(800).getBytes(), 4); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node), equalTo(1)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node.id()), equalTo(1)); } // Two nodes { @@ -559,9 +559,9 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA Node node2 = new Node("n_2", ByteSizeValue.ofMb(800).getBytes(), 8); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(2)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(2)); } // Three nodes { @@ -570,10 +570,10 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA Node node3 = new Node("n_3", ByteSizeValue.ofMb(800).getBytes(), 15); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment)).computePlan(); assertThat(assignmentPlan.assignments(deployment).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment).get(); - assertThat(assignments.get(node1), equalTo(1)); - assertThat(assignments.get(node2), equalTo(2)); - assertThat(assignments.get(node3), equalTo(5)); + Map assignments = assignmentPlan.assignments(deployment).get(); + assertThat(assignments.get(node1.id()), equalTo(1)); + assertThat(assignments.get(node2.id()), equalTo(2)); + assertThat(assignments.get(node3.id()), equalTo(5)); } } @@ -594,7 +594,7 @@ public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() { AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment).isPresent(), is(true)); - assertThat(plan.assignments(deployment).get(), equalTo(Map.of(node, 4))); + assertThat(plan.assignments(deployment).get(), equalTo(Map.of("n_1", 4))); } public void testFullCoreUtilization_GivenDeploymentsWithSingleThreadPerAllocation() { @@ -625,7 +625,7 @@ public void testFullCoreUtilization_GivenDeploymentsWithSingleThreadPerAllocatio int usedCores = 0; for (AssignmentPlan.Deployment m : deployments) { - Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); + Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); usedCores += assignments.values().stream().mapToInt(Integer::intValue).sum(); } assertThat(usedCores, equalTo(64)); @@ -728,7 +728,7 @@ public void testFullCoreUtilization_GivenDeploymentsWithSingleThreadPerAllocatio int usedCores = 0; for (AssignmentPlan.Deployment m : deployments) { - Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); + Map assignments = assignmentPlan.assignments(m).orElse(Map.of()); usedCores += assignments.values().stream().mapToInt(Integer::intValue).sum(); } assertThat(usedCores, equalTo(64)); @@ -819,7 +819,7 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode List previousModelsPlusNew = new ArrayList<>(deployments.size() + 1); for (Deployment m : deployments) { - Map assignments = originalPlan.assignments(m).orElse(Map.of()); + Map assignments = originalPlan.assignments(m).orElse(Map.of()); Map previousAssignments = assignments.entrySet() .stream() .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); @@ -869,17 +869,17 @@ public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAss assertThat(assignmentPlan.getRemainingNodeMemory("n_3"), greaterThanOrEqualTo(0L)); { assertThat(assignmentPlan.assignments(deployment1).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment1).get(); - assertThat(assignments.get(node1), equalTo(2)); - assertThat(assignments.get(node2), equalTo(1)); - assertThat(assignments.get(node3), is(nullValue())); + Map assignments = assignmentPlan.assignments(deployment1).get(); + assertThat(assignments.get(node1.id()), equalTo(2)); + assertThat(assignments.get(node2.id()), equalTo(1)); + assertThat(assignments.get(node3.id()), is(nullValue())); } { assertThat(assignmentPlan.assignments(deployment2).isPresent(), is(true)); - Map assignments = assignmentPlan.assignments(deployment2).get(); - assertThat(assignments.get(node1), is(nullValue())); - assertThat(assignments.get(node2), is(nullValue())); - assertThat(assignments.get(node3), equalTo(2)); + Map assignments = assignmentPlan.assignments(deployment2).get(); + assertThat(assignments.get(node1.id()), is(nullValue())); + assertThat(assignments.get(node2.id()), is(nullValue())); + assertThat(assignments.get(node3.id()), equalTo(2)); } } @@ -1171,14 +1171,8 @@ public void testGivenClusterResize_ShouldRemoveAllocatedDeployments_NewMemoryFie public static List createDeploymentsFromPlan(AssignmentPlan plan) { List deployments = new ArrayList<>(); for (Deployment m : plan.deployments()) { - Optional> assignments = plan.assignments(m); - Map currentAllocations = Map.of(); - if (assignments.isPresent()) { - currentAllocations = new HashMap<>(); - for (Map.Entry nodeAssignments : assignments.get().entrySet()) { - currentAllocations.put(nodeAssignments.getKey().id(), nodeAssignments.getValue()); - } - } + Optional> assignments = plan.assignments(m); + Map currentAllocations = assignments.orElse(Map.of()); int totalAllocations = currentAllocations.values().stream().mapToInt(Integer::intValue).sum(); deployments.add( new Deployment( @@ -1201,21 +1195,18 @@ public static List createDeploymentsFromPlan(AssignmentPlan plan) { public static Map> convertToIdIndexed(AssignmentPlan plan) { Map> result = new HashMap<>(); for (AssignmentPlan.Deployment m : plan.deployments()) { - Optional> assignments = plan.assignments(m); - Map allocationsPerNodeId = assignments.isPresent() ? new HashMap<>() : Map.of(); - for (Map.Entry nodeAssignments : assignments.orElse(Map.of()).entrySet()) { - allocationsPerNodeId.put(nodeAssignments.getKey().id(), nodeAssignments.getValue()); - } + Optional> assignments = plan.assignments(m); + Map allocationsPerNodeId = assignments.orElse(Map.of()); result.put(m.deploymentId(), allocationsPerNodeId); } return result; } public static void assertModelFullyAssignedToNode(AssignmentPlan plan, Deployment m, Node n) { - Optional> assignments = plan.assignments(m); + Optional> assignments = plan.assignments(m); assertThat(assignments.isPresent(), is(true)); assertThat(assignments.get().size(), equalTo(1)); - assertThat(assignments.get().get(n), equalTo(m.allocations())); + assertThat(assignments.get().get(n.id()), equalTo(m.allocations())); } public static List randomNodes(int scale) { @@ -1281,12 +1272,12 @@ public static Deployment randomModel(String idSuffix) { public static void assertPreviousAssignmentsAreSatisfied(List deployments, AssignmentPlan assignmentPlan) { for (Deployment m : deployments.stream().filter(m -> m.currentAllocationsByNodeId().isEmpty() == false).toList()) { - Map assignments = assignmentPlan.assignments(m).get(); + Map assignments = assignmentPlan.assignments(m).get(); Set assignedNodeIds = new HashSet<>(); int allocations = 0; - for (Map.Entry e : assignments.entrySet()) { - assignedNodeIds.add(e.getKey().id()); - if (m.currentAllocationsByNodeId().containsKey(e.getKey().id())) { + for (Map.Entry e : assignments.entrySet()) { + assignedNodeIds.add(e.getKey()); + if (m.currentAllocationsByNodeId().containsKey(e.getKey())) { assertThat(e.getValue(), greaterThanOrEqualTo(1)); } allocations += e.getValue(); @@ -1296,20 +1287,6 @@ public static void assertPreviousAssignmentsAreSatisfied(List nodes = new ArrayList<>(); - for (int i = 0; i < nodesSize; i++) { - nodes.add(new Node("n_" + i, ByteSizeValue.ofGb(6).getBytes(), 100)); - } - List deployments = new ArrayList<>(); - for (int i = 0; i < modelsSize; i++) { - deployments.add(new Deployment("m_" + i, "m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, null, 0, 0)); - } - - // Check plan is computed without OOM exception - new AssignmentPlanner(nodes, deployments).computePlan(); - } - private static Quality computeQuality(List nodes, List deployments, AssignmentPlan assignmentPlan) { final int totalCores = nodes.stream().map(Node::cores).mapToInt(Integer::intValue).sum(); final int totalAllocationRequired = deployments.stream() @@ -1319,7 +1296,7 @@ private static Quality computeQuality(List nodes, List deploym int usedCores = 0; int assignedAllocations = 0; for (Deployment m : deployments) { - for (Map.Entry assignment : assignmentPlan.assignments(m).orElse(Map.of()).entrySet()) { + for (Map.Entry assignment : assignmentPlan.assignments(m).orElse(Map.of()).entrySet()) { assignedAllocations += assignment.getValue(); usedCores += assignment.getValue() * m.threadsPerAllocation(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java index d22394ec86a77..672fd6d3b8b98 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java @@ -103,12 +103,12 @@ public void testGivenPreviousAssignments() { AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) .assignModelToNode(deployment1, node1, 2) .build(); - assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1, 2))); + assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1.id(), 2))); assertThat(plan.assignments(deployment2), isEmpty()); plan = preserveAllAllocations.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2), isPresentWith(Map.of(node1, 1, node2, 2))); + assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1.id(), 3))); + assertThat(plan.assignments(deployment2), isPresentWith(Map.of(node1.id(), 1, node2.id(), 2))); // Node 1 already had deployments 1 and 2 assigned to it so adding more allocation doesn't change memory usage. assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L)); @@ -192,12 +192,12 @@ public void testGivenPreviousAssignments() { AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) .assignModelToNode(deployment1, node1, 2) .build(); - assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1, 2))); + assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1.id(), 2))); assertThat(plan.assignments(deployment2), isEmpty()); plan = preserveAllAllocations.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2), isPresentWith(Map.of(node1, 1, node2, 2))); + assertThat(plan.assignments(deployment1), isPresentWith(Map.of(node1.id(), 3))); + assertThat(plan.assignments(deployment2), isPresentWith(Map.of(node1.id(), 1, node2.id(), 2))); // 1000 - ((30 + 300 + 3*10) + (50 + 300 + 10)) = 280 : deployments use 720 MB on the node 1 assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes())); @@ -219,7 +219,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments assertThat(plan.assignments(deployment), isEmpty()); plan = preserveAllAllocations.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment), isPresentWith(Map.of(node, 2))); + assertThat(plan.assignments(deployment), isPresentWith(Map.of(node.id(), 2))); assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes())); assertThat(plan.getRemainingNodeCores("n_1"), equalTo(0)); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java index 6f340900276ff..67d394df34be8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java @@ -123,13 +123,13 @@ public void testGivenPreviousAssignments() { .assignModelToNode(deployment1, node1, 2) .assignModelToNode(deployment2, node2, 1) .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of("n_1", 2))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of("n_2", 1))); plan = preserveOneAllocation.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of("n_1", 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of("n_1", 1, "n_2", 2))); // Node 1 already had deployments 1 and 2 assigned to it so adding more allocation doesn't change memory usage. assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(0L)); // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node @@ -213,13 +213,13 @@ public void testGivenPreviousAssignments() { .assignModelToNode(deployment1, node1, 2) .assignModelToNode(deployment2, node2, 1) .build(); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 2))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node2, 1))); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of("n_1", 2))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of("n_2", 1))); plan = preserveOneAllocation.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment1).get(), equalTo(Map.of(node1, 3))); - assertThat(plan.assignments(deployment2).get(), equalTo(Map.of(node1, 1, node2, 2))); + assertThat(plan.assignments(deployment1).get(), equalTo(Map.of("n_1", 3))); + assertThat(plan.assignments(deployment2).get(), equalTo(Map.of("n_1", 1, "n_2", 2))); // 1000 - [(30+300+3*10) + (50+300+10)] = 280 : deployments use 720MB on the node assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(280).getBytes())); // 8 - ((1*1+1*4) + 2*1) = 1 : deployments use 7 cores on the node @@ -243,7 +243,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments assertThat(plan.assignments(deployment), isEmpty()); plan = preserveOneAllocation.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment), isPresentWith(Map.of(node, 1))); + assertThat(plan.assignments(deployment), isPresentWith(Map.of("n_1", 1))); // 400 - (30*2 + 240) = 100 : deployments use 300MB on the node assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(100).getBytes())); assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); @@ -269,7 +269,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments assertThat(plan.assignments(deployment), isEmpty()); plan = preserveOneAllocation.mergePreservedAllocations(plan); - assertThat(plan.assignments(deployment), isPresentWith(Map.of(node, 1))); + assertThat(plan.assignments(deployment), isPresentWith(Map.of("n_1", 1))); // 400 - (30 + 300 + 10) = 60 : deployments use 340MB on the node assertThat(plan.getRemainingNodeMemory("n_1"), equalTo(ByteSizeValue.ofMb(60).getBytes())); assertThat(plan.getRemainingNodeCores("n_1"), equalTo(2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java index ea23bfde0d848..83f444e71b557 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.ml.inference.assignment.planning; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Deployment; import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node; @@ -109,9 +108,9 @@ public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { assertThat(plan.satisfiesAllModels(), is(true)); assertThat(plan.assignments(deployment).isPresent(), is(true)); - Map assignments = plan.assignments(deployment).get(); + Map assignments = plan.assignments(deployment).get(); assertThat(assignments.keySet(), hasSize(1)); - assertThat(assignments.get(assignments.keySet().iterator().next()), equalTo(1)); + assertThat(assignments.values().iterator().next(), equalTo(1)); } public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { @@ -225,24 +224,27 @@ public void testGivenThreeDeployments_TwoNodesPerZone_ThreeZones_FullyFit() { { assertThat(plan.assignments(deployment1).isPresent(), is(true)); - Map assignments = plan.assignments(deployment1).get(); + Map assignments = plan.assignments(deployment1).get(); for (List zoneNodes : nodesByZone.values()) { - assertThat(Sets.haveNonEmptyIntersection(assignments.keySet(), zoneNodes.stream().collect(Collectors.toSet())), is(true)); + var zoneIds = zoneNodes.stream().map(Node::id).collect(Collectors.toSet()); + assertThat(assignments.keySet().stream().anyMatch(zoneIds::contains), is(true)); } } { assertThat(plan.assignments(deployment2).isPresent(), is(true)); - Map assignments = plan.assignments(deployment2).get(); + Map assignments = plan.assignments(deployment2).get(); for (List zoneNodes : nodesByZone.values()) { - assertThat(Sets.haveNonEmptyIntersection(assignments.keySet(), zoneNodes.stream().collect(Collectors.toSet())), is(true)); + var zoneIds = zoneNodes.stream().map(Node::id).collect(Collectors.toSet()); + assertThat(assignments.keySet().stream().anyMatch(zoneIds::contains), is(true)); } } { assertThat(plan.assignments(deployment3).isPresent(), is(true)); - Map assignments = plan.assignments(deployment3).get(); + Map assignments = plan.assignments(deployment3).get(); int zonesWithAllocations = 0; for (List zoneNodes : nodesByZone.values()) { - if (Sets.haveNonEmptyIntersection(assignments.keySet(), zoneNodes.stream().collect(Collectors.toSet()))) { + var zoneIds = zoneNodes.stream().map(Node::id).collect(Collectors.toSet()); + if (assignments.keySet().stream().anyMatch(zoneIds::contains)) { zonesWithAllocations++; } } @@ -281,10 +283,8 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode List previousModelsPlusNew = new ArrayList<>(deployments.size() + 1); for (AssignmentPlan.Deployment m : deployments) { - Map assignments = originalPlan.assignments(m).orElse(Map.of()); - Map previousAssignments = assignments.entrySet() - .stream() - .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue)); + Map assignments = originalPlan.assignments(m).orElse(Map.of()); + Map previousAssignments = assignments; previousModelsPlusNew.add( new AssignmentPlan.Deployment( m.deploymentId(),