Skip to content

Commit c5ab81c

Browse files
min-guk, mboehm7
authored and committed
[SYSTEMDS-3790] Extraction of optimal FedPlans and conflict handling
Closes #2175.
1 parent 242a3c9 commit c5ab81c

File tree

8 files changed

+507
-160
lines changed

8 files changed

+507
-160
lines changed

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java

Lines changed: 71 additions & 143 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
package org.apache.sysds.hops.fedplanner;
2121

2222
import org.apache.sysds.hops.Hop;
23-
import org.apache.sysds.hops.OptimizerUtils;
2423
import org.apache.commons.lang3.tuple.Pair;
2524
import org.apache.commons.lang3.tuple.ImmutablePair;
2625
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
@@ -29,8 +28,6 @@
2928
import java.util.List;
3029
import java.util.ArrayList;
3130
import java.util.Map;
32-
import java.util.HashSet;
33-
import java.util.Set;
3431

3532
/**
3633
* A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes.
@@ -71,12 +68,11 @@ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List<Pair<Long, F
7168
* Retrieves the minimum cost child plan considering the parent's output type.
7269
* Plan variants are compared by their total cost (FedPlan::getTotalCost).
7370
*
74-
* @param childHopID ?
75-
* @param childFedOutType ?
76-
* @return ?
71+
* @param fedPlanPair pair of (hop ID, federated output type) identifying the memo table entry
72+
* @return min cost fed plan
7773
*/
78-
public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput childFedOutType) {
79-
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(childHopID, childFedOutType));
74+
public FedPlan getMinCostFedPlan(Pair<Long, FederatedOutput> fedPlanPair) {
75+
FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair);
8076
return fedPlanVariantList._fedPlanVariants.stream()
8177
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
8278
.orElse(null);
@@ -86,6 +82,22 @@ public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType
8682
return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
8783
}
8884

85+
public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput> fedPlanPair) {
86+
return hopMemoTable.get(fedPlanPair);
87+
}
88+
89+
public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) {
90+
// TODO: Consider verifying that pruning has been performed before returning the first variant
91+
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
92+
return fedPlanVariantList._fedPlanVariants.get(0);
93+
}
94+
95+
public FedPlan getFedPlanAfterPrune(Pair<Long, FederatedOutput> fedPlanPair) {
96+
// TODO: Consider verifying that pruning has been performed before returning the first variant
97+
FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair);
98+
return fedPlanVariantList._fedPlanVariants.get(0);
99+
}
100+
89101
/**
90102
* Checks if the memo table contains an entry for a given Hop and fedOutType.
91103
*
@@ -98,162 +110,77 @@ public boolean contains(long hopID, FederatedOutput fedOutType) {
98110
}
99111

100112
/**
101-
* Prunes all entries in the memo table, retaining only the minimum-cost
102-
* FedPlan for each entry.
103-
*/
104-
public void pruneMemoTable() {
105-
for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants> entry : hopMemoTable.entrySet()) {
106-
List<FedPlan> fedPlanList = entry.getValue().getFedPlanVariants();
107-
if (fedPlanList.size() > 1) {
108-
// Find the FedPlan with the minimum cost
109-
FedPlan minCostPlan = fedPlanList.stream()
110-
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
111-
.orElse(null);
112-
113-
// Retain only the minimum cost plan
114-
fedPlanList.clear();
115-
fedPlanList.add(minCostPlan);
116-
}
117-
}
118-
}
119-
120-
/**
121-
* Recursively prints a tree representation of the DAG starting from the given root FedPlan.
122-
* Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
113+
* Prunes the specified entry in the memo table, retaining only the minimum-cost
114+
* FedPlan for the given Hop ID and federated output type.
123115
*
124-
* @param rootFedPlan The starting point FedPlan to print
116+
* @param hopID The ID of the Hop to prune
117+
* @param federatedOutput The federated output type associated with the Hop
125118
*/
126-
public void printFedPlanTree(FedPlan rootFedPlan) {
127-
Set<FedPlan> visited = new HashSet<>();
128-
printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
119+
public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) {
120+
hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune();
129121
}
130122

131123
/**
132-
* Helper method to recursively print the FedPlan tree.
133-
*
134-
* @param plan The current FedPlan to print
135-
* @param visited Set to keep track of visited FedPlans (prevents cycles)
136-
* @param depth The current depth level for indentation
137-
* @param isLast Whether this node is the last child of its parent
124+
* Represents common properties and costs associated with a Hop.
125+
* This class holds a reference to the Hop and tracks its execution and network transfer costs.
138126
*/
139-
private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan> visited, int depth, boolean isLast) {
140-
if (plan == null || visited.contains(plan)) {
141-
return;
142-
}
143-
144-
visited.add(plan);
145-
146-
Hop hop = plan.getHopRef();
147-
StringBuilder sb = new StringBuilder();
148-
149-
// Add FedPlan information
150-
sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
151-
.append(plan.getHopRef().getOpString())
152-
.append(" [")
153-
.append(plan.getFedOutType())
154-
.append("]");
155-
156-
StringBuilder childs = new StringBuilder();
157-
childs.append(" (");
158-
boolean childAdded = false;
159-
for( Hop input : hop.getInput()){
160-
childs.append(childAdded?",":"");
161-
childs.append(input.getHopID());
162-
childAdded = true;
163-
}
164-
childs.append(")");
165-
if( childAdded )
166-
sb.append(childs.toString());
167-
168-
169-
sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
170-
plan.getTotalCost(),
171-
plan.getSelfCost(),
172-
plan.getNetTransferCost()));
173-
174-
// Add matrix characteristics
175-
sb.append(" [")
176-
.append(hop.getDim1()).append(", ")
177-
.append(hop.getDim2()).append(", ")
178-
.append(hop.getBlocksize()).append(", ")
179-
.append(hop.getNnz());
180-
181-
if (hop.getUpdateType().isInPlace()) {
182-
sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
183-
}
184-
sb.append("]");
185-
186-
// Add memory estimates
187-
sb.append(" [")
188-
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
189-
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
190-
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
191-
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
192-
193-
// Add reblock and checkpoint requirements
194-
if (hop.requiresReblock() && hop.requiresCheckpoint()) {
195-
sb.append(" [rblk, chkpt]");
196-
} else if (hop.requiresReblock()) {
197-
sb.append(" [rblk]");
198-
} else if (hop.requiresCheckpoint()) {
199-
sb.append(" [chkpt]");
200-
}
201-
202-
// Add execution type
203-
if (hop.getExecType() != null) {
204-
sb.append(", ").append(hop.getExecType());
205-
}
206-
207-
System.out.println(sb);
208-
209-
// Process child nodes
210-
List<Pair<Long, FederatedOutput>> childRefs = plan.getChildFedPlans();
211-
for (int i = 0; i < childRefs.size(); i++) {
212-
Pair<Long, FederatedOutput> childRef = childRefs.get(i);
213-
FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight());
214-
if (childVariants == null || childVariants.getFedPlanVariants().isEmpty())
215-
continue;
127+
public static class HopCommon {
128+
protected final Hop hopRef; // Reference to the associated Hop
129+
protected double selfCost; // Current execution cost (compute + memory access)
130+
protected double netTransferCost; // Network transfer cost
216131

217-
boolean isLastChild = (i == childRefs.size() - 1);
218-
for (FedPlan childPlan : childVariants.getFedPlanVariants()) {
219-
printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild);
220-
}
132+
protected HopCommon(Hop hopRef) {
133+
this.hopRef = hopRef;
134+
this.selfCost = 0;
135+
this.netTransferCost = 0;
221136
}
222137
}
223138

224139
/**
225-
* Represents a collection of federated execution plan variants for a specific Hop.
226-
* Contains cost information and references to the associated plans.
140+
* Represents a collection of federated execution plan variants for a specific Hop and FederatedOutput.
141+
* This class contains cost information and references to the associated plans.
142+
* It uses HopCommon to store common properties and costs related to the Hop.
227143
*/
228144
public static class FedPlanVariants {
229-
protected final Hop hopRef; // Reference to the associated Hop
230-
protected double selfCost; // Current execution cost (compute + memory access)
231-
protected double netTransferCost; // Network transfer cost
232-
private final FederatedOutput fedOutType; // Output type (FOUT/LOUT)
233-
protected List<FedPlan> _fedPlanVariants; // List of plan variants
145+
protected HopCommon hopCommon; // Common properties and costs for the Hop
146+
private final FederatedOutput fedOutType; // Output type (FOUT/LOUT)
147+
protected List<FedPlan> _fedPlanVariants; // List of plan variants
234148

235149
public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
236-
this.hopRef = hopRef;
150+
this.hopCommon = new HopCommon(hopRef);
237151
this.fedOutType = fedOutType;
238-
this.selfCost = 0;
239-
this.netTransferCost = 0;
240152
this._fedPlanVariants = new ArrayList<>();
241153
}
242154

243-
public int size() {return _fedPlanVariants.size();}
244155
public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);}
245156
public List<FedPlan> getFedPlanVariants() {return _fedPlanVariants;}
157+
public boolean isEmpty() {return _fedPlanVariants.isEmpty();}
158+
159+
public void prune() {
160+
if (_fedPlanVariants.size() > 1) {
161+
// Find the FedPlan with the minimum cost
162+
FedPlan minCostPlan = _fedPlanVariants.stream()
163+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
164+
.orElse(null);
165+
166+
// Retain only the minimum cost plan
167+
_fedPlanVariants.clear();
168+
_fedPlanVariants.add(minCostPlan);
169+
}
170+
}
246171
}
247172

248173
/**
249174
* Represents a single federated execution plan with its associated costs and dependencies.
250-
* Contains:
175+
* This class contains:
251176
* 1. selfCost: Cost of current hop (compute + input/output memory access)
252177
* 2. totalCost: Cumulative cost including this plan and all child plans
253178
* 3. netTransferCost: Network transfer cost from this plan to its parent plan.
179+
*
180+
* FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs.
254181
*/
255182
public static class FedPlan {
256-
private double totalCost; // Total cost including child plans
183+
private double totalCost; // Total cost including child plans
257184
private final FedPlanVariants fedPlanVariants; // Reference to variant list
258185
private final List<Pair<Long, FederatedOutput>> childFedPlans; // Child plan references
259186

@@ -264,25 +191,26 @@ public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, FedPlanVariants
264191
}
265192

266193
public void setTotalCost(double totalCost) {this.totalCost = totalCost;}
267-
public void setSelfCost(double selfCost) {fedPlanVariants.selfCost = selfCost;}
268-
public void setNetTransferCost(double netTransferCost) {fedPlanVariants.netTransferCost = netTransferCost;}
269-
270-
public Hop getHopRef() {return fedPlanVariants.hopRef;}
194+
public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;}
195+
public void setNetTransferCost(double netTransferCost) {fedPlanVariants.hopCommon.netTransferCost = netTransferCost;}
196+
197+
public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;}
198+
public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();}
271199
public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;}
272200
public double getTotalCost() {return totalCost;}
273-
public double getSelfCost() {return fedPlanVariants.selfCost;}
274-
private double getNetTransferCost() {return fedPlanVariants.netTransferCost;}
201+
public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;}
202+
public double getNetTransferCost() {return fedPlanVariants.hopCommon.netTransferCost;}
275203
public List<Pair<Long, FederatedOutput>> getChildFedPlans() {return childFedPlans;}
276204

277205
/**
278206
* Calculates the conditional network transfer cost based on output type compatibility.
279207
* Returns 0 if output types match, otherwise returns the network transfer cost.
280-
* @param parentFedOutType ?
281-
* @return ?
208+
* @param parentFedOutType The federated output type of the parent plan.
209+
* @return The conditional network transfer cost.
282210
*/
283211
public double getCondNetTransferCost(FederatedOutput parentFedOutType) {
284212
if (parentFedOutType == getFedOutType()) return 0;
285-
return fedPlanVariants.netTransferCost;
213+
return fedPlanVariants.hopCommon.netTransferCost;
286214
}
287215
}
288216
}

0 commit comments

Comments
 (0)