Skip to content

Commit 283359a

Browse files
min-gukmboehm7
authored andcommitted
[SYSTEMDS-3790] Improvements of cost-based federated planner
Closes #2273, closes #2294.
1 parent 2c737bc commit 283359a

35 files changed

+4830
-1163
lines changed

src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public AFederatedPlanner getPlanner() {
3434
case COMPILE_FED_HEURISTIC:
3535
return new FederatedPlannerFedHeuristic();
3636
case COMPILE_COST_BASED:
37+
return new FederatedPlannerFedCostBased();
3738
case NONE:
3839
case RUNTIME:
3940
default:
@@ -130,4 +131,10 @@ public boolean isColType() {
130131
return (this == COL || this == COL_T);
131132
}
132133
}
134+
135+
public enum Privacy {
136+
PRIVATE,
137+
PRIVATE_AGGREGATE,
138+
PUBLIC
139+
}
133140
}

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.commons.lang3.tuple.Pair;
2929
import org.apache.commons.lang3.tuple.ImmutablePair;
3030
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
31+
import org.apache.sysds.common.Types.ExecType;
3132

3233
/**
3334
* A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes.
@@ -46,13 +47,19 @@ public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput> fedPlanPai
4647
return hopMemoTable.get(fedPlanPair);
4748
}
4849

49-
public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) {
50-
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
50+
public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput federatedOutput) {
51+
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput));
52+
if (fedPlanVariantList == null || fedPlanVariantList.isEmpty()) {
53+
return null;
54+
}
5155
return fedPlanVariantList._fedPlanVariants.get(0);
5256
}
5357

5458
public FedPlan getFedPlanAfterPrune(Pair<Long, FederatedOutput> fedPlanPair) {
5559
FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair);
60+
if (fedPlanVariantList == null || fedPlanVariantList.isEmpty()) {
61+
return null;
62+
}
5663
return fedPlanVariantList._fedPlanVariants.get(0);
5764
}
5865

@@ -61,13 +68,17 @@ public boolean contains(long hopID, FederatedOutput fedOutType) {
6168
}
6269

6370
/**
64-
* Represents a single federated execution plan with its associated costs and dependencies.
71+
* Represents a single federated execution plan with its associated costs and
72+
* dependencies.
6573
* This class contains:
66-
* 1. selfCost: Cost of the current hop (computation + input/output memory access).
67-
* 2. cumulativeCost: Total cost including this plan's selfCost and all child plans' cumulativeCost.
74+
* 1. selfCost: Cost of the current hop (computation + input/output memory
75+
* access).
76+
* 2. cumulativeCost: Total cost including this plan's selfCost and all child
77+
* plans' cumulativeCost.
6878
* 3. forwardingCost: Network transfer cost for this plan to the parent plan.
6979
*
70-
* FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs.
80+
* FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage
81+
* common properties and costs.
7182
*/
7283
public static class FedPlan {
7384
private double cumulativeCost; // Total cost = sum of selfCost + cumulativeCost of child plans
@@ -84,10 +95,31 @@ public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List<Pair
8495
public long getHopID() {return fedPlanVariants.hopCommon.getHopRef().getHopID();}
8596
public FederatedOutput getFedOutType() {return fedPlanVariants.getFedOutType();}
8697
public double getCumulativeCost() {return cumulativeCost;}
98+
public double getCumulativeCostPerParents() {
99+
double cumulativeCostPerParents = cumulativeCost;
100+
int numOfParents = fedPlanVariants.hopCommon.getNumOfParents();
101+
if (numOfParents >= 2){
102+
cumulativeCostPerParents /= numOfParents;
103+
}
104+
return cumulativeCostPerParents;
105+
}
87106
public double getSelfCost() {return fedPlanVariants.hopCommon.getSelfCost();}
88107
public double getForwardingCost() {return fedPlanVariants.hopCommon.getForwardingCost();}
89-
public double getWeight() {return fedPlanVariants.hopCommon.getWeight();}
108+
public double getForwardingCostPerParents() {
109+
double forwardingCostPerParents = fedPlanVariants.hopCommon.getForwardingCost();
110+
int numOfParents = fedPlanVariants.hopCommon.getNumOfParents();
111+
if (numOfParents >= 2){
112+
forwardingCostPerParents /= numOfParents;
113+
}
114+
return forwardingCostPerParents;
115+
}
116+
public double getComputeWeight() {return fedPlanVariants.hopCommon.getComputeWeight();}
117+
public double getNetworkWeight() {return fedPlanVariants.hopCommon.getNetworkWeight();}
118+
public double getChildForwardingWeight(List<Pair<Long, Double>> childLoopContext) {return fedPlanVariants.hopCommon.getChildForwardingWeight(childLoopContext);}
119+
public List<Pair<Long, Double>> getLoopContext() {return fedPlanVariants.hopCommon.getLoopContext();}
90120
public List<Pair<Long, FederatedOutput>> getChildFedPlans() {return childFedPlans;}
121+
public void setFederatedOutput(FederatedOutput fedOutType) {fedPlanVariants.hopCommon.hopRef.setFederatedOutput(fedOutType);}
122+
public void setForcedExecType(ExecType execType) {fedPlanVariants.hopCommon.hopRef.setForcedExecType(execType);}
91123
}
92124

93125
/**
@@ -111,8 +143,8 @@ public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) {
111143
public List<FedPlan> getFedPlanVariants() {return _fedPlanVariants;}
112144
public FederatedOutput getFedOutType() {return fedOutType;}
113145

114-
public void pruneFedPlans() {
115-
if (_fedPlanVariants.size() > 1) {
146+
public boolean pruneFedPlans() {
147+
if (!_fedPlanVariants.isEmpty()) {
116148
// Find the FedPlan with the minimum cumulative cost
117149
FedPlan minCostPlan = _fedPlanVariants.stream()
118150
.min(Comparator.comparingDouble(FedPlan::getCumulativeCost))
@@ -121,33 +153,63 @@ public void pruneFedPlans() {
121153
// Retain only the minimum cost plan
122154
_fedPlanVariants.clear();
123155
_fedPlanVariants.add(minCostPlan);
156+
return true;
124157
}
158+
return false;
125159
}
126160
}
127161

128162
/**
129163
* Represents common properties and costs associated with a Hop.
130164
* This class holds a reference to the Hop and tracks its execution and network forwarding (transfer) costs.
165+
* It also maintains the loop context information to properly calculate forwarding costs within loops.
131166
*/
132167
public static class HopCommon {
133168
protected final Hop hopRef; // Reference to the associated Hop
134169
protected double selfCost; // Cost of the hop's computation and memory access
135170
protected double forwardingCost; // Cost of forwarding the hop's output to its parent
136-
protected double weight; // Weight used to calculate cost based on hop execution frequency
171+
protected int numOfParents;
172+
protected double computeWeight; // Weight used to calculate cost based on hop execution frequency
173+
protected double networkWeight; // Weight used to calculate cost based on hop execution frequency
174+
protected List<Pair<Long, Double>> loopContext; // Loop context in which this hop exists
137175

138-
public HopCommon(Hop hopRef, double weight) {
176+
public HopCommon(Hop hopRef, double computeWeight, double networkWeight, int numOfParents, List<Pair<Long, Double>> loopContext) {
139177
this.hopRef = hopRef;
140178
this.selfCost = 0;
141179
this.forwardingCost = 0;
142-
this.weight = weight;
180+
this.numOfParents = numOfParents;
181+
this.computeWeight = computeWeight;
182+
this.networkWeight = networkWeight;
183+
this.loopContext = loopContext != null ? new ArrayList<>(loopContext) : new ArrayList<>();
143184
}
144185

145186
public Hop getHopRef() {return hopRef;}
146187
public double getSelfCost() {return selfCost;}
147188
public double getForwardingCost() {return forwardingCost;}
148-
public double getWeight() {return weight;}
189+
public double getComputeWeight() {return computeWeight;}
190+
public double getNetworkWeight() {return networkWeight;}
191+
public int getNumOfParents() {return numOfParents;}
192+
public List<Pair<Long, Double>> getLoopContext() {return loopContext;}
149193

150194
protected void setSelfCost(double selfCost) {this.selfCost = selfCost;}
151195
protected void setForwardingCost(double forwardingCost) {this.forwardingCost = forwardingCost;}
196+
protected void setNumOfParentHops(int numOfParentHops) {this.numOfParents = numOfParentHops;}
197+
198+
public double getChildForwardingWeight(List<Pair<Long, Double>> childLoopContext) {
199+
if (loopContext.isEmpty()) {
200+
return networkWeight;
201+
}
202+
203+
double forwardingWeight = this.networkWeight;
204+
205+
for (int i = 0; i < loopContext.size(); i++) {
206+
if (i >= childLoopContext.size() || loopContext.get(i).getLeft() != childLoopContext.get(i).getLeft()) {
207+
forwardingWeight /=loopContext.get(i).getRight();
208+
}
209+
}
210+
211+
// Check if the innermost loops are the same
212+
return forwardingWeight;
213+
}
152214
}
153215
}

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ private static void printFedPlan(FederatedMemoTable.FedPlan plan, int depth, boo
169169
plan.getCumulativeCost(),
170170
plan.getSelfCost(),
171171
plan.getForwardingCost(),
172-
plan.getWeight()));
172+
plan.getComputeWeight()));
173173

174174
// Add matrix characteristics
175175
sb.append(" [")

0 commit comments

Comments
 (0)