2828
2929import org .apache .commons .lang3 .tuple .Pair ;
3030import org .apache .commons .lang3 .tuple .ImmutablePair ;
31+ import org .apache .sysds .common .Types ;
32+ import org .apache .sysds .hops .DataOp ;
3133import org .apache .sysds .hops .Hop ;
3234import org .apache .sysds .hops .fedplanner .FederatedMemoTable .FedPlan ;
3335import org .apache .sysds .hops .fedplanner .FederatedMemoTable .FedPlanVariants ;
5052 */
5153public class FederatedPlanCostEnumerator {
5254 public static void enumerateProgram (DMLProgram prog ) {
55+ FederatedMemoTable memoTable = new FederatedMemoTable ();
56+ Map <String , Long > transTable = new HashMap <>();
57+
5358 for (StatementBlock sb : prog .getStatementBlocks ())
54- enumerateStatementBlock (sb );
59+ enumerateStatementBlock (sb , memoTable , transTable );
5560 }
5661
5762 /**
@@ -61,7 +66,7 @@ public static void enumerateProgram(DMLProgram prog) {
6166 *
6267 * @param sb The statement block to enumerate.
6368 */
64- public static void enumerateStatementBlock (StatementBlock sb ) {
69+ public static void enumerateStatementBlock (StatementBlock sb , FederatedMemoTable memoTable , Map < String , Long > transTable ) {
6570 // While enumerating the program, recursively determine the optimal FedPlan and MemoTable
6671 // for each statement block and statement.
6772 // 1. How to recursively integrate optimal FedPlans and MemoTables across statements and statement blocks?
@@ -75,12 +80,12 @@ public static void enumerateStatementBlock(StatementBlock sb) {
7580 IfStatementBlock isb = (IfStatementBlock ) sb ;
7681 IfStatement istmt = (IfStatement )isb .getStatement (0 );
7782
78- enumerateFederatedPlanCost (isb .getPredicateHops ());
83+ enumerateHopDAG (isb .getPredicateHops (), memoTable , transTable );
7984
8085 for (StatementBlock csb : istmt .getIfBody ())
81- enumerateStatementBlock (csb );
86+ enumerateStatementBlock (csb , memoTable , transTable );
8287 for (StatementBlock csb : istmt .getElseBody ())
83- enumerateStatementBlock (csb );
88+ enumerateStatementBlock (csb , memoTable , transTable );
8489
8590 // Todo: 1. apply iteration weight to csbFedPlans (if: 0.5, else: 0.5)
8691 // Todo: 2. Merge predFedPlans
@@ -89,24 +94,24 @@ public static void enumerateStatementBlock(StatementBlock sb) {
8994
9095 ForStatement fstmt = (ForStatement )fsb .getStatement (0 );
9196
92- enumerateFederatedPlanCost (fsb .getFromHops ());
93- enumerateFederatedPlanCost (fsb .getToHops ());
94- enumerateFederatedPlanCost (fsb .getIncrementHops ());
97+ enumerateHopDAG (fsb .getFromHops (), memoTable , transTable );
98+ enumerateHopDAG (fsb .getToHops (), memoTable , transTable );
99+ enumerateHopDAG (fsb .getIncrementHops (), memoTable , transTable );
95100
96101 for (StatementBlock csb : fstmt .getBody ())
97- enumerateStatementBlock (csb );
102+ enumerateStatementBlock (csb , memoTable , transTable );
98103
99104 // Todo: 1. get(predict) # of Iterations
100105 // Todo: 2. apply iteration weight to csbFedPlans
101106 // Todo: 3. Merge csbFedPlans and predFedPlans
102107 } else if (sb instanceof WhileStatementBlock ) {
103108 WhileStatementBlock wsb = (WhileStatementBlock ) sb ;
104109 WhileStatement wstmt = (WhileStatement )wsb .getStatement (0 );
105- enumerateFederatedPlanCost (wsb .getPredicateHops ());
110+ enumerateHopDAG (wsb .getPredicateHops (), memoTable , transTable );
106111
107112 ArrayList <FedPlan > csbFedPlans = new ArrayList <>();
108113 for (StatementBlock csb : wstmt .getBody ())
109- enumerateStatementBlock (csb );
114+ enumerateStatementBlock (csb , memoTable , transTable );
110115
111116 // Todo: 1. get(predict) # of Iterations
112117 // Todo: 2. apply iteration weight to csbFedPlans
@@ -115,13 +120,13 @@ public static void enumerateStatementBlock(StatementBlock sb) {
115120 FunctionStatementBlock fsb = (FunctionStatementBlock )sb ;
116121 FunctionStatement fstmt = (FunctionStatement )fsb .getStatement (0 );
117122 for (StatementBlock csb : fstmt .getBody ())
118- enumerateStatementBlock (csb );
123+ enumerateStatementBlock (csb , memoTable , transTable );
119124
120125 // Todo: 1. Merge csbFedPlans
121126 } else { //generic (last-level)
122127 if ( sb .getHops () != null )
123128 for ( Hop c : sb .getHops () )
124- enumerateFederatedPlanCost ( c );
129+ enumerateHopDAG ( c , memoTable , transTable );
125130 }
126131 }
127132
@@ -133,12 +138,11 @@ public static void enumerateStatementBlock(StatementBlock sb) {
133138 * @param rootHop The root Hop node from which to start the plan enumeration.
134139 * @return The optimal FedPlan with the minimum cost for the entire DAG.
135140 */
136- public static FedPlan enumerateFederatedPlanCost (Hop rootHop ) {
141+ public static FedPlan enumerateHopDAG (Hop rootHop , FederatedMemoTable memoTable , Map < String , Long > transTable ) {
137142 // Create new memo table to store all plan variants
138- FederatedMemoTable memoTable = new FederatedMemoTable ();
139143
140144 // Recursively enumerate all possible plans
141- enumerateFederatedPlanCost (rootHop , memoTable );
145+ enumerateHop (rootHop , memoTable , transTable );
142146
143147 // Return the minimum cost plan for the root node
144148 FedPlan optimalPlan = getMinCostRootFedPlan (rootHop .getHopID (), memoTable );
@@ -167,14 +171,29 @@ public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {
167171 * @param hop ?
168172 * @param memoTable ?
169173 */
170- private static void enumerateFederatedPlanCost (Hop hop , FederatedMemoTable memoTable ) {
174+ private static void enumerateHop (Hop hop , FederatedMemoTable memoTable , Map < String , Long > transTable ) {
171175 int numInputs = hop .getInput ().size ();
172176
173177 // Process all input nodes first if not already in memo table
174178 for (Hop inputHop : hop .getInput ()) {
175179 if (!memoTable .contains (inputHop .getHopID (), FederatedOutput .FOUT )
176180 && !memoTable .contains (inputHop .getHopID (), FederatedOutput .LOUT )) {
177- enumerateFederatedPlanCost (inputHop , memoTable );
181+ enumerateHop (inputHop , memoTable , transTable );
182+ }
183+ }
184+
185+ if (hop instanceof DataOp
186+ && ((DataOp )hop ).getOp ()== Types .OpOpData .TRANSIENTWRITE
187+ && !(hop .getName ().equals ("__pred" ))){
188+ transTable .put (hop .getName (), hop .getHopID ());
189+ }
190+
191+ if (hop instanceof DataOp
192+ && !(hop .getName ().equals ("__pred" ))){
193+ if (((DataOp )hop ).getOp ()== Types .OpOpData .TRANSIENTWRITE ){
194+ transTable .put (hop .getName (), hop .getHopID ());
195+ } else if (((DataOp )hop ).getOp ()== Types .OpOpData .TRANSIENTREAD ){
196+ long rWriteHopID = transTable .get (hop .getName ());
178197 }
179198 }
180199
0 commit comments