Skip to content

Commit 727200e

Browse files
committed
Update detectConflictFedPlan, resolveConflictFedPlan, and MemoTablePrinter
1 parent 07cdbbd commit 727200e

File tree

5 files changed

+306
-218
lines changed

5 files changed

+306
-218
lines changed

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java

Lines changed: 32 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,9 @@ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List<Pair<Long, F
7070
/**
7171
* Retrieves the minimum cost child plan considering the parent's output type.
7272
* The cost is calculated using getParentViewCost to account for potential type mismatches.
73-
*
74-
* @param childHopID ?
75-
* @param childFedOutType ?
76-
* @return ?
7773
*/
78-
public FedPlan getMinCostFedPlan(long hopID, FederatedOutput fedOutType) {
79-
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
74+
public FedPlan getMinCostFedPlan(Pair<Long, FederatedOutput> fedPlanPair) {
75+
FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair);
8076
return fedPlanVariantList._fedPlanVariants.stream()
8177
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
8278
.orElse(null);
@@ -86,12 +82,22 @@ public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType
8682
return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
8783
}
8884

85+
public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput> fedPlanPair) {
86+
return hopMemoTable.get(fedPlanPair);
87+
}
88+
8989
public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) {
9090
// Todo: Consider whether to verify if pruning has been performed
9191
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
9292
return fedPlanVariantList._fedPlanVariants.get(0);
9393
}
9494

95+
public FedPlan getFedPlanAfterPrune(Pair<Long, FederatedOutput> fedPlanPair) {
96+
// Todo: Consider whether to verify if pruning has been performed
97+
FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair);
98+
return fedPlanVariantList._fedPlanVariants.get(0);
99+
}
100+
95101
/**
96102
* Checks if the memo table contains an entry for a given Hop and fedOutType.
97103
*
@@ -104,128 +110,14 @@ public boolean contains(long hopID, FederatedOutput fedOutType) {
104110
}
105111

106112
/**
107-
* Prunes all entries in the memo table, retaining only the minimum-cost
108-
* FedPlan for each entry.
109-
*/
110-
public void pruneMemoTable() {
111-
for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants> entry : hopMemoTable.entrySet()) {
112-
List<FedPlan> fedPlanList = entry.getValue().getFedPlanVariants();
113-
if (fedPlanList.size() > 1) {
114-
// Find the FedPlan with the minimum cost
115-
FedPlan minCostPlan = fedPlanList.stream()
116-
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
117-
.orElse(null);
118-
119-
// Retain only the minimum cost plan
120-
fedPlanList.clear();
121-
fedPlanList.add(minCostPlan);
122-
}
123-
}
124-
}
125-
126-
// Todo: Separate print functions from FederatedMemoTable
127-
/**
128-
* Recursively prints a tree representation of the DAG starting from the given root FedPlan.
129-
* Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
113+
* Prunes the specified entry in the memo table, retaining only the minimum-cost
114+
* FedPlan for the given Hop ID and federated output type.
130115
*
131-
* @param rootFedPlan The starting point FedPlan to print
116+
* @param hopID The ID of the Hop to prune
117+
* @param federatedOutput The federated output type associated with the Hop
132118
*/
133-
public void printFedPlanTree(FedPlan rootFedPlan) {
134-
Set<FedPlan> visited = new HashSet<>();
135-
printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
136-
}
137-
138-
/**
139-
* Helper method to recursively print the FedPlan tree.
140-
*
141-
* @param plan The current FedPlan to print
142-
* @param visited Set to keep track of visited FedPlans (prevents cycles)
143-
* @param depth The current depth level for indentation
144-
* @param isLast Whether this node is the last child of its parent
145-
*/
146-
private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan> visited, int depth, boolean isLast) {
147-
if (plan == null || visited.contains(plan)) {
148-
return;
149-
}
150-
151-
visited.add(plan);
152-
153-
Hop hop = plan.getHopRef();
154-
StringBuilder sb = new StringBuilder();
155-
156-
// Add FedPlan information
157-
sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
158-
.append(plan.getHopRef().getOpString())
159-
.append(" [")
160-
.append(plan.getFedOutType())
161-
.append("]");
162-
163-
StringBuilder childs = new StringBuilder();
164-
childs.append(" (");
165-
boolean childAdded = false;
166-
for( Hop input : hop.getInput()){
167-
childs.append(childAdded?",":"");
168-
childs.append(input.getHopID());
169-
childAdded = true;
170-
}
171-
childs.append(")");
172-
if( childAdded )
173-
sb.append(childs.toString());
174-
175-
176-
sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
177-
plan.getTotalCost(),
178-
plan.getSelfCost(),
179-
plan.getNetTransferCost()));
180-
181-
// Add matrix characteristics
182-
sb.append(" [")
183-
.append(hop.getDim1()).append(", ")
184-
.append(hop.getDim2()).append(", ")
185-
.append(hop.getBlocksize()).append(", ")
186-
.append(hop.getNnz());
187-
188-
if (hop.getUpdateType().isInPlace()) {
189-
sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
190-
}
191-
sb.append("]");
192-
193-
// Add memory estimates
194-
sb.append(" [")
195-
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
196-
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
197-
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
198-
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
199-
200-
// Add reblock and checkpoint requirements
201-
if (hop.requiresReblock() && hop.requiresCheckpoint()) {
202-
sb.append(" [rblk, chkpt]");
203-
} else if (hop.requiresReblock()) {
204-
sb.append(" [rblk]");
205-
} else if (hop.requiresCheckpoint()) {
206-
sb.append(" [chkpt]");
207-
}
208-
209-
// Add execution type
210-
if (hop.getExecType() != null) {
211-
sb.append(", ").append(hop.getExecType());
212-
}
213-
214-
System.out.println(sb);
215-
216-
// Process child nodes
217-
List<Pair<Long, FederatedOutput>> childRefs = plan.getChildFedPlans();
218-
for (int i = 0; i < childRefs.size(); i++) {
219-
Pair<Long, FederatedOutput> childRef = childRefs.get(i);
220-
FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight());
221-
if (childVariants == null || childVariants.getFedPlanVariants().isEmpty())
222-
continue;
223-
224-
boolean isLastChild = (i == childRefs.size() - 1);
225-
for (FedPlan childPlan : childVariants.getFedPlanVariants()) {
226-
printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild);
227-
}
228-
}
119+
public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) {
120+
hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune();
229121
}
230122

231123
/**
@@ -262,6 +154,20 @@ public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
262154

263155
public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);}
264156
public List<FedPlan> getFedPlanVariants() {return _fedPlanVariants;}
157+
public boolean isEmpty() {return _fedPlanVariants.isEmpty();}
158+
159+
public void prune() {
160+
if (_fedPlanVariants.size() > 1) {
161+
// Find the FedPlan with the minimum cost
162+
FedPlan minCostPlan = _fedPlanVariants.stream()
163+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
164+
.orElse(null);
165+
166+
// Retain only the minimum cost plan
167+
_fedPlanVariants.clear();
168+
_fedPlanVariants.add(minCostPlan);
169+
}
170+
}
265171
}
266172

267173
/**
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package org.apache.sysds.hops.fedplanner;
2+
3+
import org.apache.commons.lang3.tuple.Pair;
4+
import org.apache.sysds.hops.Hop;
5+
import org.apache.sysds.hops.OptimizerUtils;
6+
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
7+
8+
import java.util.HashSet;
9+
import java.util.List;
10+
import java.util.Set;
11+
12+
public class FederatedMemoTablePrinter {
13+
/**
14+
* Recursively prints a tree representation of the DAG starting from the given root FedPlan.
15+
* Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
16+
* Additionally, prints the additional total cost once at the beginning.
17+
*
18+
* @param rootFedPlan The starting point FedPlan to print
19+
* @param memoTable The memoization table containing FedPlan variants
20+
* @param additionalTotalCost The additional cost to be printed once
21+
*/
22+
public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, FederatedMemoTable memoTable,
23+
double additionalTotalCost) {
24+
System.out.println("Additional Cost: " + additionalTotalCost);
25+
Set<FederatedMemoTable.FedPlan> visited = new HashSet<>();
26+
printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0);
27+
}
28+
29+
/**
30+
* Helper method to recursively print the FedPlan tree.
31+
*
32+
* @param plan The current FedPlan to print
33+
* @param visited Set to keep track of visited FedPlans (prevents cycles)
34+
* @param depth The current depth level for indentation
35+
*/
36+
private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable,
37+
Set<FederatedMemoTable.FedPlan> visited, int depth) {
38+
if (plan == null || visited.contains(plan)) {
39+
return;
40+
}
41+
42+
visited.add(plan);
43+
44+
Hop hop = plan.getHopRef();
45+
StringBuilder sb = new StringBuilder();
46+
47+
// Add FedPlan information
48+
sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
49+
.append(plan.getHopRef().getOpString())
50+
.append(" [")
51+
.append(plan.getFedOutType())
52+
.append("]");
53+
54+
StringBuilder childs = new StringBuilder();
55+
childs.append(" (");
56+
boolean childAdded = false;
57+
for( Hop input : hop.getInput()){
58+
childs.append(childAdded?",":"");
59+
childs.append(input.getHopID());
60+
childAdded = true;
61+
}
62+
childs.append(")");
63+
if( childAdded )
64+
sb.append(childs.toString());
65+
66+
67+
sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
68+
plan.getTotalCost(),
69+
plan.getSelfCost(),
70+
plan.getNetTransferCost()));
71+
72+
// Add matrix characteristics
73+
sb.append(" [")
74+
.append(hop.getDim1()).append(", ")
75+
.append(hop.getDim2()).append(", ")
76+
.append(hop.getBlocksize()).append(", ")
77+
.append(hop.getNnz());
78+
79+
if (hop.getUpdateType().isInPlace()) {
80+
sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
81+
}
82+
sb.append("]");
83+
84+
// Add memory estimates
85+
sb.append(" [")
86+
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
87+
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
88+
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
89+
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
90+
91+
// Add reblock and checkpoint requirements
92+
if (hop.requiresReblock() && hop.requiresCheckpoint()) {
93+
sb.append(" [rblk, chkpt]");
94+
} else if (hop.requiresReblock()) {
95+
sb.append(" [rblk]");
96+
} else if (hop.requiresCheckpoint()) {
97+
sb.append(" [chkpt]");
98+
}
99+
100+
// Add execution type
101+
if (hop.getExecType() != null) {
102+
sb.append(", ").append(hop.getExecType());
103+
}
104+
105+
System.out.println(sb);
106+
107+
// Process child nodes
108+
List<Pair<Long, FEDInstruction.FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans();
109+
for (int i = 0; i < childFedPlanPairs.size(); i++) {
110+
Pair<Long, FEDInstruction.FederatedOutput> childFedPlanPair = childFedPlanPairs.get(i);
111+
FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair);
112+
if (childVariants == null || childVariants.isEmpty())
113+
continue;
114+
115+
for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) {
116+
printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1);
117+
}
118+
}
119+
}
120+
}

0 commit comments

Comments
 (0)