Skip to content

Commit fd9479d

Browse files
committed
program level fed planer
1 parent 16a8d00 commit fd9479d

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
import org.apache.commons.lang3.tuple.Pair;
3030
import org.apache.commons.lang3.tuple.ImmutablePair;
31+
import org.apache.sysds.common.Types;
32+
import org.apache.sysds.hops.DataOp;
3133
import org.apache.sysds.hops.Hop;
3234
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
3335
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
@@ -50,8 +52,11 @@
5052
*/
5153
public class FederatedPlanCostEnumerator {
5254
public static void enumerateProgram(DMLProgram prog) {
55+
FederatedMemoTable memoTable = new FederatedMemoTable();
56+
Map<String, Long> transTable = new HashMap<>();
57+
5358
for(StatementBlock sb : prog.getStatementBlocks())
54-
enumerateStatementBlock(sb);
59+
enumerateStatementBlock(sb, memoTable, transTable);
5560
}
5661

5762
/**
@@ -61,7 +66,7 @@ public static void enumerateProgram(DMLProgram prog) {
6166
*
6267
* @param sb The statement block to enumerate.
6368
*/
64-
public static void enumerateStatementBlock(StatementBlock sb) {
69+
public static void enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map<String, Long> transTable) {
6570
// While enumerating the program, recursively determine the optimal FedPlan and MemoTable
6671
// for each statement block and statement.
6772
// 1. How to recursively integrate optimal FedPlans and MemoTables across statements and statement blocks?
@@ -75,12 +80,12 @@ public static void enumerateStatementBlock(StatementBlock sb) {
7580
IfStatementBlock isb = (IfStatementBlock) sb;
7681
IfStatement istmt = (IfStatement)isb.getStatement(0);
7782

78-
enumerateFederatedPlanCost(isb.getPredicateHops());
83+
enumerateHopDAG(isb.getPredicateHops(), memoTable, transTable);
7984

8085
for (StatementBlock csb : istmt.getIfBody())
81-
enumerateStatementBlock(csb);
86+
enumerateStatementBlock(csb, memoTable, transTable);
8287
for (StatementBlock csb : istmt.getElseBody())
83-
enumerateStatementBlock(csb);
88+
enumerateStatementBlock(csb, memoTable, transTable);
8489

8590
// Todo: 1. apply iteration weight to csbFedPlans (if: 0.5, else: 0.5)
8691
// Todo: 2. Merge predFedPlans
@@ -89,24 +94,24 @@ public static void enumerateStatementBlock(StatementBlock sb) {
8994

9095
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
9196

92-
enumerateFederatedPlanCost(fsb.getFromHops());
93-
enumerateFederatedPlanCost(fsb.getToHops());
94-
enumerateFederatedPlanCost(fsb.getIncrementHops());
97+
enumerateHopDAG(fsb.getFromHops(), memoTable, transTable);
98+
enumerateHopDAG(fsb.getToHops(), memoTable, transTable);
99+
enumerateHopDAG(fsb.getIncrementHops(), memoTable, transTable);
95100

96101
for (StatementBlock csb : fstmt.getBody())
97-
enumerateStatementBlock(csb);
102+
enumerateStatementBlock(csb, memoTable, transTable);
98103

99104
// Todo: 1. get(predict) # of Iterations
100105
// Todo: 2. apply iteration weight to csbFedPlans
101106
// Todo: 3. Merge csbFedPlans and predFedPlans
102107
} else if (sb instanceof WhileStatementBlock) {
103108
WhileStatementBlock wsb = (WhileStatementBlock) sb;
104109
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
105-
enumerateFederatedPlanCost(wsb.getPredicateHops());
110+
enumerateHopDAG(wsb.getPredicateHops(), memoTable, transTable);
106111

107112
ArrayList<FedPlan> csbFedPlans = new ArrayList<>();
108113
for (StatementBlock csb : wstmt.getBody())
109-
enumerateStatementBlock(csb);
114+
enumerateStatementBlock(csb, memoTable, transTable);
110115

111116
// Todo: 1. get(predict) # of Iterations
112117
// Todo: 2. apply iteration weight to csbFedPlans
@@ -115,13 +120,13 @@ public static void enumerateStatementBlock(StatementBlock sb) {
115120
FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
116121
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
117122
for (StatementBlock csb : fstmt.getBody())
118-
enumerateStatementBlock(csb);
123+
enumerateStatementBlock(csb, memoTable, transTable);
119124

120125
// Todo: 1. Merge csbFedPlans
121126
} else { //generic (last-level)
122127
if( sb.getHops() != null )
123128
for( Hop c : sb.getHops() )
124-
enumerateFederatedPlanCost(c);
129+
enumerateHopDAG(c, memoTable, transTable);
125130
}
126131
}
127132

@@ -133,12 +138,11 @@ public static void enumerateStatementBlock(StatementBlock sb) {
133138
* @param rootHop The root Hop node from which to start the plan enumeration.
134139
* @return The optimal FedPlan with the minimum cost for the entire DAG.
135140
*/
136-
public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {
141+
public static FedPlan enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map<String, Long> transTable) {
137142
// Create new memo table to store all plan variants
138-
FederatedMemoTable memoTable = new FederatedMemoTable();
139143

140144
// Recursively enumerate all possible plans
141-
enumerateFederatedPlanCost(rootHop, memoTable);
145+
enumerateHop(rootHop, memoTable, transTable);
142146

143147
// Return the minimum cost plan for the root node
144148
FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable);
@@ -167,14 +171,29 @@ public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {
167171
* @param hop ?
168172
* @param memoTable ?
169173
*/
170-
private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) {
174+
private static void enumerateHop(Hop hop, FederatedMemoTable memoTable, Map<String, Long> transTable) {
171175
int numInputs = hop.getInput().size();
172176

173177
// Process all input nodes first if not already in memo table
174178
for (Hop inputHop : hop.getInput()) {
175179
if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT)
176180
&& !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) {
177-
enumerateFederatedPlanCost(inputHop, memoTable);
181+
enumerateHop(inputHop, memoTable, transTable);
182+
}
183+
}
184+
185+
if (hop instanceof DataOp
186+
&& ((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTWRITE
187+
&& !(hop.getName().equals("__pred"))){
188+
transTable.put(hop.getName(), hop.getHopID());
189+
}
190+
191+
if (hop instanceof DataOp
192+
&& !(hop.getName().equals("__pred"))){
193+
if (((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTWRITE){
194+
transTable.put(hop.getName(), hop.getHopID());
195+
} else if (((DataOp)hop).getOp()== Types.OpOpData.TRANSIENTREAD){
196+
long rWriteHopID = transTable.get(hop.getName());
178197
}
179198
}
180199

src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@
2020
package org.apache.sysds.test.component.federated;
2121

2222
import java.io.IOException;
23+
import java.util.Arrays;
24+
import java.util.Collection;
2325
import java.util.HashMap;
2426

27+
import org.apache.sysds.common.Types;
2528
import org.apache.sysds.hops.Hop;
29+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
30+
import org.apache.sysds.test.TestUtils;
31+
import org.apache.sysds.test.functions.federated.algorithms.FederatedL2SVMTest;
2632
import org.junit.Assert;
2733
import org.junit.Test;
2834
import org.apache.sysds.api.DMLScript;
@@ -67,9 +73,6 @@ public void setUp() {}
6773
@Test
6874
public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); }
6975

70-
@Test
71-
public void testFederatedPlanCostEnumerator8() { runTest("FederatedPlanCostEnumeratorTest4.dml"); }
72-
7376
// Todo: Need to write test scripts for the federated version
7477
private void runTest( String scriptFilename ) {
7578
int index = scriptFilename.lastIndexOf(".dml");

0 commit comments

Comments
 (0)