Skip to content

Commit 29b4d92

Browse files
min-gukmboehm7
authored andcommitted
[SYSTEMDS-3790] Rework FedPlanner memo table, cost estimator, enumerator
Closes #2147.
1 parent d3fcfb1 commit 29b4d92

File tree

7 files changed

+652
-346
lines changed

7 files changed

+652
-346
lines changed
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.hops.fedplanner;
21+
22+
import org.apache.sysds.hops.Hop;
23+
import org.apache.sysds.hops.OptimizerUtils;
24+
import org.apache.commons.lang3.tuple.Pair;
25+
import org.apache.commons.lang3.tuple.ImmutablePair;
26+
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
27+
import java.util.Comparator;
28+
import java.util.HashMap;
29+
import java.util.List;
30+
import java.util.ArrayList;
31+
import java.util.Map;
32+
import java.util.HashSet;
33+
import java.util.Set;
34+
35+
/**
36+
* A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes.
37+
* This table stores and manages different execution plan variants for each Hop and fedOutType combination,
38+
* facilitating the optimization of federated execution plans.
39+
*/
40+
public class FederatedMemoTable {
41+
// Maps Hop ID and fedOutType pairs to their plan variants
42+
private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();
43+
44+
/**
45+
* Adds a new federated plan to the memo table.
46+
* Creates a new variant list if none exists for the given Hop and fedOutType.
47+
*
48+
* @param hop The Hop node
49+
* @param fedOutType The federated output type
50+
* @param planChilds List of child plan references
51+
* @return The newly created FedPlan
52+
*/
53+
public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List<Pair<Long, FederatedOutput>> planChilds) {
54+
long hopID = hop.getHopID();
55+
FedPlanVariants fedPlanVariantList;
56+
57+
if (contains(hopID, fedOutType)) {
58+
fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
59+
} else {
60+
fedPlanVariantList = new FedPlanVariants(hop, fedOutType);
61+
hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList);
62+
}
63+
64+
FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList);
65+
fedPlanVariantList.addFedPlan(newPlan);
66+
67+
return newPlan;
68+
}
69+
70+
/**
71+
* Retrieves the minimum cost child plan considering the parent's output type.
72+
* The cost is calculated using getParentViewCost to account for potential type mismatches.
73+
*
74+
* @param childHopID ?
75+
* @param childFedOutType ?
76+
* @return ?
77+
*/
78+
public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput childFedOutType) {
79+
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(childHopID, childFedOutType));
80+
return fedPlanVariantList._fedPlanVariants.stream()
81+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
82+
.orElse(null);
83+
}
84+
85+
public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) {
86+
return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
87+
}
88+
89+
/**
90+
* Checks if the memo table contains an entry for a given Hop and fedOutType.
91+
*
92+
* @param hopID The Hop ID.
93+
* @param fedOutType The associated fedOutType.
94+
* @return True if the entry exists, false otherwise.
95+
*/
96+
public boolean contains(long hopID, FederatedOutput fedOutType) {
97+
return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType));
98+
}
99+
100+
/**
101+
* Prunes all entries in the memo table, retaining only the minimum-cost
102+
* FedPlan for each entry.
103+
*/
104+
public void pruneMemoTable() {
105+
for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants> entry : hopMemoTable.entrySet()) {
106+
List<FedPlan> fedPlanList = entry.getValue().getFedPlanVariants();
107+
if (fedPlanList.size() > 1) {
108+
// Find the FedPlan with the minimum cost
109+
FedPlan minCostPlan = fedPlanList.stream()
110+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
111+
.orElse(null);
112+
113+
// Retain only the minimum cost plan
114+
fedPlanList.clear();
115+
fedPlanList.add(minCostPlan);
116+
}
117+
}
118+
}
119+
120+
/**
121+
* Recursively prints a tree representation of the DAG starting from the given root FedPlan.
122+
* Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
123+
*
124+
* @param rootFedPlan The starting point FedPlan to print
125+
*/
126+
public void printFedPlanTree(FedPlan rootFedPlan) {
127+
Set<FedPlan> visited = new HashSet<>();
128+
printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
129+
}
130+
131+
/**
132+
* Helper method to recursively print the FedPlan tree.
133+
*
134+
* @param plan The current FedPlan to print
135+
* @param visited Set to keep track of visited FedPlans (prevents cycles)
136+
* @param depth The current depth level for indentation
137+
* @param isLast Whether this node is the last child of its parent
138+
*/
139+
private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan> visited, int depth, boolean isLast) {
140+
if (plan == null || visited.contains(plan)) {
141+
return;
142+
}
143+
144+
visited.add(plan);
145+
146+
Hop hop = plan.getHopRef();
147+
StringBuilder sb = new StringBuilder();
148+
149+
// Add FedPlan information
150+
sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
151+
.append(plan.getHopRef().getOpString())
152+
.append(" [")
153+
.append(plan.getFedOutType())
154+
.append("]");
155+
156+
StringBuilder childs = new StringBuilder();
157+
childs.append(" (");
158+
boolean childAdded = false;
159+
for( Hop input : hop.getInput()){
160+
childs.append(childAdded?",":"");
161+
childs.append(input.getHopID());
162+
childAdded = true;
163+
}
164+
childs.append(")");
165+
if( childAdded )
166+
sb.append(childs.toString());
167+
168+
169+
sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
170+
plan.getTotalCost(),
171+
plan.getSelfCost(),
172+
plan.getNetTransferCost()));
173+
174+
// Add matrix characteristics
175+
sb.append(" [")
176+
.append(hop.getDim1()).append(", ")
177+
.append(hop.getDim2()).append(", ")
178+
.append(hop.getBlocksize()).append(", ")
179+
.append(hop.getNnz());
180+
181+
if (hop.getUpdateType().isInPlace()) {
182+
sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
183+
}
184+
sb.append("]");
185+
186+
// Add memory estimates
187+
sb.append(" [")
188+
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
189+
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
190+
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
191+
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
192+
193+
// Add reblock and checkpoint requirements
194+
if (hop.requiresReblock() && hop.requiresCheckpoint()) {
195+
sb.append(" [rblk, chkpt]");
196+
} else if (hop.requiresReblock()) {
197+
sb.append(" [rblk]");
198+
} else if (hop.requiresCheckpoint()) {
199+
sb.append(" [chkpt]");
200+
}
201+
202+
// Add execution type
203+
if (hop.getExecType() != null) {
204+
sb.append(", ").append(hop.getExecType());
205+
}
206+
207+
System.out.println(sb);
208+
209+
// Process child nodes
210+
List<Pair<Long, FederatedOutput>> childRefs = plan.getChildFedPlans();
211+
for (int i = 0; i < childRefs.size(); i++) {
212+
Pair<Long, FederatedOutput> childRef = childRefs.get(i);
213+
FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight());
214+
if (childVariants == null || childVariants.getFedPlanVariants().isEmpty())
215+
continue;
216+
217+
boolean isLastChild = (i == childRefs.size() - 1);
218+
for (FedPlan childPlan : childVariants.getFedPlanVariants()) {
219+
printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild);
220+
}
221+
}
222+
}
223+
224+
/**
225+
* Represents a collection of federated execution plan variants for a specific Hop.
226+
* Contains cost information and references to the associated plans.
227+
*/
228+
public static class FedPlanVariants {
229+
protected final Hop hopRef; // Reference to the associated Hop
230+
protected double selfCost; // Current execution cost (compute + memory access)
231+
protected double netTransferCost; // Network transfer cost
232+
private final FederatedOutput fedOutType; // Output type (FOUT/LOUT)
233+
protected List<FedPlan> _fedPlanVariants; // List of plan variants
234+
235+
public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
236+
this.hopRef = hopRef;
237+
this.fedOutType = fedOutType;
238+
this.selfCost = 0;
239+
this.netTransferCost = 0;
240+
this._fedPlanVariants = new ArrayList<>();
241+
}
242+
243+
public int size() {return _fedPlanVariants.size();}
244+
public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);}
245+
public List<FedPlan> getFedPlanVariants() {return _fedPlanVariants;}
246+
}
247+
248+
/**
249+
* Represents a single federated execution plan with its associated costs and dependencies.
250+
* Contains:
251+
* 1. selfCost: Cost of current hop (compute + input/output memory access)
252+
* 2. totalCost: Cumulative cost including this plan and all child plans
253+
* 3. netTransferCost: Network transfer cost for this plan to parent plan.
254+
*/
255+
public static class FedPlan {
256+
private double totalCost; // Total cost including child plans
257+
private final FedPlanVariants fedPlanVariants; // Reference to variant list
258+
private final List<Pair<Long, FederatedOutput>> childFedPlans; // Child plan references
259+
260+
public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, FedPlanVariants fedPlanVariants) {
261+
this.totalCost = 0;
262+
this.childFedPlans = childFedPlans;
263+
this.fedPlanVariants = fedPlanVariants;
264+
}
265+
266+
public void setTotalCost(double totalCost) {this.totalCost = totalCost;}
267+
public void setSelfCost(double selfCost) {fedPlanVariants.selfCost = selfCost;}
268+
public void setNetTransferCost(double netTransferCost) {fedPlanVariants.netTransferCost = netTransferCost;}
269+
270+
public Hop getHopRef() {return fedPlanVariants.hopRef;}
271+
public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;}
272+
public double getTotalCost() {return totalCost;}
273+
public double getSelfCost() {return fedPlanVariants.selfCost;}
274+
private double getNetTransferCost() {return fedPlanVariants.netTransferCost;}
275+
public List<Pair<Long, FederatedOutput>> getChildFedPlans() {return childFedPlans;}
276+
277+
/**
278+
* Calculates the conditional network transfer cost based on output type compatibility.
279+
* Returns 0 if output types match, otherwise returns the network transfer cost.
280+
* @param parentFedOutType ?
281+
* @return ?
282+
*/
283+
public double getCondNetTransferCost(FederatedOutput parentFedOutType) {
284+
if (parentFedOutType == getFedOutType()) return 0;
285+
return fedPlanVariants.netTransferCost;
286+
}
287+
}
288+
}

0 commit comments

Comments
 (0)