Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ private static AssignmentPlan mergePlans(
*/
static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map<String, AssignmentPlan.Node> originalNodeById) {
for (AssignmentPlan.Deployment deployment : source.deployments()) {
Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
Map<String, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
for (Map.Entry<String, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey());
dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
}
}
Expand Down Expand Up @@ -337,23 +337,23 @@ private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(Assignme
assignmentBuilder.setMaxAssignedAllocations(existingAssignment.getMaxAssignedAllocations());
}

Map<AssignmentPlan.Node, Integer> assignments = assignmentPlan.assignments(deployment).orElseGet(Map::of);
for (Map.Entry<AssignmentPlan.Node, Integer> assignment : assignments.entrySet()) {
if (existingAssignment != null && existingAssignment.isRoutedToNode(assignment.getKey().id())) {
RoutingInfo existingRoutingInfo = existingAssignment.getNodeRoutingTable().get(assignment.getKey().id());
Map<String, Integer> assignments = assignmentPlan.assignments(deployment).orElseGet(Map::of);
for (Map.Entry<String, Integer> assignment : assignments.entrySet()) {
if (existingAssignment != null && existingAssignment.isRoutedToNode(assignment.getKey())) {
RoutingInfo existingRoutingInfo = existingAssignment.getNodeRoutingTable().get(assignment.getKey());
RoutingState state = existingRoutingInfo.getState();
String reason = existingRoutingInfo.getReason();
if (state == RoutingState.FAILED) {
state = RoutingState.STARTING;
reason = "";
}
assignmentBuilder.addRoutingEntry(
assignment.getKey().id(),
assignment.getKey(),
new RoutingInfo(existingRoutingInfo.getCurrentAllocations(), assignment.getValue(), state, reason)
);
} else {
assignmentBuilder.addRoutingEntry(
assignment.getKey().id(),
assignment.getKey(),
new RoutingInfo(assignment.getValue(), assignment.getValue(), RoutingState.STARTING, "")
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ AssignmentPlan mergePreservedAllocations(AssignmentPlan assignmentPlan) {
// with its preserved allocations.
final Map<Tuple<String, String>, Integer> plannedAssignmentsByDeploymentNodeIdPair = new HashMap<>();
for (Deployment d : assignmentPlan.deployments()) {
Map<Node, Integer> assignmentsOfDeployment = assignmentPlan.assignments(d).orElse(Map.of());
for (Map.Entry<Node, Integer> nodeAssignment : assignmentsOfDeployment.entrySet()) {
Map<String, Integer> assignmentsOfDeployment = assignmentPlan.assignments(d).orElse(Map.of());
for (Map.Entry<String, Integer> nodeAssignment : assignmentsOfDeployment.entrySet()) {
plannedAssignmentsByDeploymentNodeIdPair.put(
Tuple.tuple(d.deploymentId(), nodeAssignment.getKey().id()),
Tuple.tuple(d.deploymentId(), nodeAssignment.getKey()),
nodeAssignment.getValue()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,15 @@ public Set<Deployment> deployments() {
* @param deployment the model for which assignments are returned
* @return the model assignments per node. The Optional will be empty if the model has no assignments.
*/
public Optional<Map<Node, Integer>> assignments(Deployment deployment) {
public Optional<Map<String, Integer>> assignments(Deployment deployment) {
Map<Node, Integer> modelAssignments = assignments.get(deployment);
return (modelAssignments == null || modelAssignments.isEmpty()) ? Optional.empty() : Optional.of(modelAssignments);
if (modelAssignments == null || modelAssignments.isEmpty()) {
return Optional.empty();
}
Map<String, Integer> byNodeId = modelAssignments.entrySet()
.stream()
.collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue));
return Optional.of(byNodeId);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat

Map<String, String> modelIdToNodeIdWithSingleAllocation = new HashMap<>();
for (AssignmentPlan.Deployment m : planWithSingleAllocationForPreviouslyAssignedModels.deployments()) {
Optional<Map<Node, Integer>> assignments = planWithSingleAllocationForPreviouslyAssignedModels.assignments(m);
Set<Node> nodes = assignments.orElse(Map.of()).keySet();
Optional<Map<String, Integer>> assignments = planWithSingleAllocationForPreviouslyAssignedModels.assignments(m);
Set<String> nodes = assignments.orElse(Map.of()).keySet();
if (nodes.isEmpty() == false) {
assert nodes.size() == 1;
modelIdToNodeIdWithSingleAllocation.put(m.deploymentId(), nodes.iterator().next().id());
modelIdToNodeIdWithSingleAllocation.put(m.deploymentId(), nodes.iterator().next());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ private AssignmentPlan swapOriginalDeploymentsInPlan(

for (AssignmentPlan.Deployment planDeployment : planDeployments) {
AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.deploymentId());
Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
Node originalNode = originalNodeById.get(assignment.getKey().id());
Map<String, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
for (Map.Entry<String, Integer> assignment : nodeAssignments.entrySet()) {
Node originalNode = originalNodeById.get(assignment.getKey());
finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
}
}
Expand All @@ -223,11 +223,11 @@ private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByDeploymentId
for (AssignmentPlan plan : plans) {
for (AssignmentPlan.Deployment m : plan.deployments()) {
Map<String, Integer> nodeIdToAllocations = allocationsByNodeIdByDeploymentId.get(m.deploymentId());
Optional<Map<Node, Integer>> assignments = plan.assignments(m);
Optional<Map<String, Integer>> assignments = plan.assignments(m);
if (assignments.isPresent()) {
for (Map.Entry<Node, Integer> nodeAssignments : assignments.get().entrySet()) {
for (Map.Entry<String, Integer> nodeAssignments : assignments.get().entrySet()) {
nodeIdToAllocations.compute(
nodeAssignments.getKey().id(),
nodeAssignments.getKey(),
(nodeId, existingAllocations) -> existingAllocations == null
? nodeAssignments.getValue()
: existingAllocations + nodeAssignments.getValue()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() {

assertThat(plan.deployments(), contains(m));
assertThat(plan.satisfiesCurrentAssignments(), is(true));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1)));
}
{ // new memory format
AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
Expand Down Expand Up @@ -107,7 +107,7 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() {

assertThat(plan.deployments(), contains(m));
assertThat(plan.satisfiesCurrentAssignments(), is(true));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1)));
}
}

Expand Down Expand Up @@ -140,7 +140,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() {

assertThat(plan.deployments(), contains(m));
assertThat(plan.satisfiesCurrentAssignments(), is(true));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1)));
}
{ // new memory format
AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
Expand Down Expand Up @@ -169,7 +169,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() {

assertThat(plan.deployments(), contains(m));
assertThat(plan.satisfiesCurrentAssignments(), is(true));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1)));

}
}
Expand All @@ -195,7 +195,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment()

assertThat(plan.deployments(), contains(m));
assertThat(plan.satisfiesCurrentAssignments(), is(false));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1)));
}
{
// new memory format
Expand Down Expand Up @@ -229,7 +229,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment()

assertThat(plan.deployments(), contains(m));
assertThat(plan.satisfiesCurrentAssignments(), is(false));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 1)));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 1)));
}
}

Expand Down Expand Up @@ -365,7 +365,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() {

assertThat(plan.deployments(), contains(m));
assertThat(plan.satisfiesCurrentAssignments(), is(true));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 3)));
assertThat(plan.assignments(m).get(), equalTo(Map.of(n.id(), 3)));
}

public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() {
Expand Down
Loading