Skip to content

Commit cd7e229

Browse files
committed
Enumerator for an optimal federated plan at the program level
1 parent 718a180 commit cd7e229

File tree

9 files changed

+245
-35
lines changed

9 files changed

+245
-35
lines changed

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package org.apache.sysds.hops.fedplanner;
2121

2222
import org.apache.sysds.hops.Hop;
23-
import org.apache.sysds.hops.OptimizerUtils;
2423
import org.apache.commons.lang3.tuple.Pair;
2524
import org.apache.commons.lang3.tuple.ImmutablePair;
2625
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
@@ -29,8 +28,6 @@
2928
import java.util.List;
3029
import java.util.ArrayList;
3130
import java.util.Map;
32-
import java.util.HashSet;
33-
import java.util.Set;
3431

3532
/**
3633
* A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes.
@@ -127,12 +124,12 @@ public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) {
127124
public static class HopCommon {
128125
protected final Hop hopRef; // Reference to the associated Hop
129126
protected double selfCost; // Current execution cost (compute + memory access)
130-
protected double netTransferCost; // Network transfer cost
127+
protected double forwardingCost; // Network transfer cost
131128

132129
protected HopCommon(Hop hopRef) {
133130
this.hopRef = hopRef;
134131
this.selfCost = 0;
135-
this.netTransferCost = 0;
132+
this.forwardingCost = 0;
136133
}
137134
}
138135

@@ -175,7 +172,7 @@ public void prune() {
175172
* This class contains:
176173
* 1. selfCost: Cost of current hop (compute + input/output memory access)
177174
* 2. totalCost: Cumulative cost including this plan and all child plans
178-
* 3. netTransferCost: Network transfer cost for this plan to parent plan.
175+
* 3. forwardingCost: Network transfer cost for this plan to parent plan.
179176
*
180177
* FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs.
181178
*/
@@ -192,14 +189,15 @@ public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, FedPlanVariants
192189

193190
public void setTotalCost(double totalCost) {this.totalCost = totalCost;}
194191
public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;}
195-
public void setNetTransferCost(double netTransferCost) {fedPlanVariants.hopCommon.netTransferCost = netTransferCost;}
196-
192+
public void setForwardingCost(double forwardingCost) {fedPlanVariants.hopCommon.forwardingCost = forwardingCost;}
193+
public void applyIterationWeight(int iteration) {totalCost *= iteration;}
194+
197195
public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;}
198196
public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();}
199197
public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;}
200198
public double getTotalCost() {return totalCost;}
201199
public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;}
202-
public double getNetTransferCost() {return fedPlanVariants.hopCommon.netTransferCost;}
200+
public double setForwardingCost() {return fedPlanVariants.hopCommon.forwardingCost;}
203201
public List<Pair<Long, FederatedOutput>> getChildFedPlans() {return childFedPlans;}
204202

205203
/**
@@ -208,9 +206,9 @@ public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, FedPlanVariants
208206
* @param parentFedOutType The federated output type of the parent plan.
209207
* @return The conditional network transfer cost.
210208
*/
211-
public double getCondNetTransferCost(FederatedOutput parentFedOutType) {
209+
public double getCondForwardingCost(FederatedOutput parentFedOutType) {
212210
if (parentFedOutType == getFedOutType()) return 0;
213-
return fedPlanVariants.hopCommon.netTransferCost;
211+
return fedPlanVariants.hopCommon.forwardingCost;
214212
}
215213
}
216214
}

src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, F
6767
sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
6868
plan.getTotalCost(),
6969
plan.getSelfCost(),
70-
plan.getNetTransferCost()));
70+
plan.setForwardingCost()));
7171

7272
// Add matrix characteristics
7373
sb.append(" [")

src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@
3131
import org.apache.sysds.hops.Hop;
3232
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
3333
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
34+
import org.apache.sysds.parser.DMLProgram;
35+
import org.apache.sysds.parser.ForStatement;
36+
import org.apache.sysds.parser.ForStatementBlock;
37+
import org.apache.sysds.parser.FunctionStatement;
38+
import org.apache.sysds.parser.FunctionStatementBlock;
39+
import org.apache.sysds.parser.IfStatement;
40+
import org.apache.sysds.parser.IfStatementBlock;
41+
import org.apache.sysds.parser.StatementBlock;
42+
import org.apache.sysds.parser.WhileStatement;
43+
import org.apache.sysds.parser.WhileStatementBlock;
3444
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
3545

3646
/**
@@ -39,16 +49,91 @@
3949
* to compute their costs.
4050
*/
4151
public class FederatedPlanCostEnumerator {
52+
public static void enumerateProgram(DMLProgram prog) {
53+
for(StatementBlock sb : prog.getStatementBlocks())
54+
enumerateStatementBlock(sb);
55+
}
56+
57+
/**
58+
* Recursively enumerates federated execution plans for a given statement block.
59+
* This method processes each type of statement block (If, For, While, Function, and generic)
60+
* to determine the optimal federated plan.
61+
*
62+
* @param sb The statement block to enumerate.
63+
*/
64+
public static void enumerateStatementBlock(StatementBlock sb) {
65+
// While enumerating the program, recursively determine the optimal FedPlan and MemoTable
66+
// for each statement block and statement.
67+
// 1. How to recursively integrate optimal FedPlans and MemoTables across statements and statement blocks?
68+
// 1) Is it determined using the same dynamic programming approach, or simply by summing the minimal plans?
69+
// 2. Is there a need to share the MemoTable? Are there data/hop dependencies between statements?
70+
// 3. How to predict the number of iterations for For and While loops?
71+
// 1) If from/to/increment are constants: Calculations can be done at compile time.
72+
// 2) If they are variables: Use default values at compile time, adjust at runtime, or predict using ML models.
73+
74+
if (sb instanceof IfStatementBlock) {
75+
IfStatementBlock isb = (IfStatementBlock) sb;
76+
IfStatement istmt = (IfStatement)isb.getStatement(0);
77+
78+
enumerateFederatedPlanCost(isb.getPredicateHops());
79+
80+
for (StatementBlock csb : istmt.getIfBody())
81+
enumerateStatementBlock(csb);
82+
for (StatementBlock csb : istmt.getElseBody())
83+
enumerateStatementBlock(csb);
84+
85+
// Todo: 1. apply iteration weight to csbFedPlans (if: 0.5, else: 0.5)
86+
// Todo: 2. Merge predFedPlans
87+
} else if (sb instanceof ForStatementBlock) { //incl parfor
88+
ForStatementBlock fsb = (ForStatementBlock) sb;
89+
90+
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
91+
92+
enumerateFederatedPlanCost(fsb.getFromHops());
93+
enumerateFederatedPlanCost(fsb.getToHops());
94+
enumerateFederatedPlanCost(fsb.getIncrementHops());
95+
96+
for (StatementBlock csb : fstmt.getBody())
97+
enumerateStatementBlock(csb);
98+
99+
// Todo: 1. get(predict) # of Iterations
100+
// Todo: 2. apply iteration weight to csbFedPlans
101+
// Todo: 3. Merge csbFedPlans and predFedPlans
102+
} else if (sb instanceof WhileStatementBlock) {
103+
WhileStatementBlock wsb = (WhileStatementBlock) sb;
104+
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
105+
enumerateFederatedPlanCost(wsb.getPredicateHops());
106+
107+
ArrayList<FedPlan> csbFedPlans = new ArrayList<>();
108+
for (StatementBlock csb : wstmt.getBody())
109+
enumerateStatementBlock(csb);
110+
111+
// Todo: 1. get(predict) # of Iterations
112+
// Todo: 2. apply iteration weight to csbFedPlans
113+
// Todo: 3. Merge csbFedPlans and predFedPlans
114+
} else if (sb instanceof FunctionStatementBlock) {
115+
FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
116+
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
117+
for (StatementBlock csb : fstmt.getBody())
118+
enumerateStatementBlock(csb);
119+
120+
// Todo: 1. Merge csbFedPlans
121+
} else { //generic (last-level)
122+
if( sb.getHops() != null )
123+
for( Hop c : sb.getHops() )
124+
enumerateFederatedPlanCost(c);
125+
}
126+
}
127+
42128
/**
43129
* Entry point for federated plan enumeration. This method creates a memo table
44130
* and returns the minimum cost plan for the entire Directed Acyclic Graph (DAG).
45131
* It also resolves conflicts where FedPlans have different FederatedOutput types.
46132
*
47133
* @param rootHop The root Hop node from which to start the plan enumeration.
48-
* @param printTree A boolean flag indicating whether to print the federated plan tree.
49134
* @return The optimal FedPlan with the minimum cost for the entire DAG.
50135
*/
51-
public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) {
136+
public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {
52137
// Create new memo table to store all plan variants
53138
FederatedMemoTable memoTable = new FederatedMemoTable();
54139

@@ -61,8 +146,8 @@ public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree)
61146
// Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types
62147
double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable);
63148

64-
// Optionally print the federated plan tree if requested
65-
if (printTree) FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost);
149+
// Print the federated plan tree if requested
150+
FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost);
66151

67152
return optimalPlan;
68153
}

src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTa
6161
totalCost = computeCurrentCost(currentHop);
6262
currentPlan.setSelfCost(totalCost);
6363
// Calculate potential network transfer cost if federation type changes
64-
currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
64+
currentPlan.setForwardingCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
6565
} else {
6666
totalCost = currentPlan.getSelfCost();
6767
}
@@ -74,7 +74,7 @@ public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTa
7474
FedPlan planRef = memoTable.getMinCostFedPlan(childPlanPair);
7575

7676
// Add child plan cost (includes network transfer cost if federation types differ)
77-
totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType());
77+
totalCost += planRef.getTotalCost() + planRef.getCondForwardingCost(currentPlan.getFedOutType());
7878
}
7979

8080
// Step 3: Set final cumulative cost including current node
@@ -111,8 +111,8 @@ public static LinkedHashMap<FedPlan, Boolean> resolveConflictFedPlan(FederatedMe
111111

112112
// Flags to check if the plan involves network transfer
113113
// Network transfer cost is calculated only once, even if it occurs multiple times
114-
boolean isLOutNetTransfer = false;
115-
boolean isFOutNetTransfer = false;
114+
boolean isLOutForwarding = false;
115+
boolean isFOutForwarding = false;
116116

117117
// Determine the optimal federated output type based on the calculated costs
118118
FederatedOutput optimalFedOutType;
@@ -143,35 +143,35 @@ public static LinkedHashMap<FedPlan, Boolean> resolveConflictFedPlan(FederatedMe
143143
if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) {
144144
// (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred
145145
// (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later
146-
isFOutNetTransfer = true;
146+
isFOutForwarding = true;
147147
} else {
148148
// Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred
149149
// (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later
150-
isLOutNetTransfer = true;
151-
lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost();
150+
isLOutForwarding = true;
151+
lOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost();
152152

153153
// (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it
154-
fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost();
154+
fOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost();
155155
}
156156
} else {
157157
lOutAdditionalCost += confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost();
158158

159159
if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) {
160-
isLOutNetTransfer = true;
160+
isLOutForwarding = true;
161161
} else {
162-
isFOutNetTransfer = true;
163-
lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost();
164-
fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost();
162+
isFOutForwarding = true;
163+
lOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost();
164+
fOutAdditionalCost -= confilctLOutFedPlan.setForwardingCost();
165165
}
166166
}
167167
}
168168

169169
// Add network transfer costs if applicable
170-
if (isLOutNetTransfer) {
171-
lOutAdditionalCost += confilctLOutFedPlan.getNetTransferCost();
170+
if (isLOutForwarding) {
171+
lOutAdditionalCost += confilctLOutFedPlan.setForwardingCost();
172172
}
173-
if (isFOutNetTransfer) {
174-
fOutAdditionalCost += confilctFOutFedPlan.getNetTransferCost();
173+
if (isFOutForwarding) {
174+
fOutAdditionalCost += confilctFOutFedPlan.setForwardingCost();
175175
}
176176

177177
// Determine the optimal federated output type based on the calculated costs

src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
4545

4646
@Override
4747
public void setUp() {}
48-
48+
4949
@Test
5050
public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); }
5151

@@ -55,6 +55,21 @@ public void setUp() {}
5555
@Test
5656
public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); }
5757

58+
@Test
59+
public void testFederatedPlanCostEnumerator4() { runTest("FederatedPlanCostEnumeratorTest4.dml"); }
60+
61+
@Test
62+
public void testFederatedPlanCostEnumerator5() { runTest("FederatedPlanCostEnumeratorTest5.dml"); }
63+
64+
@Test
65+
public void testFederatedPlanCostEnumerator6() { runTest("FederatedPlanCostEnumeratorTest6.dml"); }
66+
67+
@Test
68+
public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); }
69+
70+
@Test
71+
public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest4.dml"); }
72+
5873
// Todo: Need to write test scripts for the federated version
5974
private void runTest( String scriptFilename ) {
6075
int index = scriptFilename.lastIndexOf(".dml");
@@ -80,8 +95,7 @@ private void runTest( String scriptFilename ) {
8095
dmlt.rewriteHopsDAG(prog);
8196
dmlt.constructLops(prog);
8297

83-
Hop hops = prog.getStatementBlocks().get(0).getHops().get(0);
84-
FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true);
98+
FederatedPlanCostEnumerator.enumerateProgram(prog);
8599
}
86100
catch (IOException e) {
87101
e.printStackTrace();
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
a = matrix(7,10,10);
23+
if (sum(a) > 0.5)
24+
b = a * 2;
25+
else
26+
b = a * 3;
27+
c = sqrt(b);
28+
print(sum(c));
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
for( i in 1:10 )
23+
{
24+
b = i + 1;
25+
print(b);
26+
}

0 commit comments

Comments
 (0)