1919
2020package org .apache .sysds .hops .fedplanner ;
2121
22- import java .util .Collections ;
2322import java .util .HashMap ;
2423import java .util .Map ;
2524
2625import org .apache .sysds .common .Types .ExecType ;
2726import org .apache .sysds .common .Types .OpOpData ;
2827import org .apache .sysds .hops .DataOp ;
28+ import org .apache .sysds .hops .FunctionOp ;
2929import org .apache .sysds .hops .Hop ;
3030import org .apache .sysds .hops .fedplanner .FTypes .FType ;
3131import org .apache .sysds .hops .ipa .FunctionCallGraph ;
@@ -89,14 +89,14 @@ private void rRewriteStatementBlock(StatementBlock sb, Map<String, FType> fedVar
8989 else if (sb instanceof WhileStatementBlock ) {
9090 WhileStatementBlock wsb = (WhileStatementBlock ) sb ;
9191 WhileStatement wstmt = (WhileStatement )wsb .getStatement (0 );
92- rRewriteHop (wsb .getPredicateHops (), new HashMap <>(), Collections . emptyMap ());
92+ rRewriteHop (wsb .getPredicateHops (), new HashMap <>(), new HashMap <>(), sb . getDMLProg ());
9393 for (StatementBlock csb : wstmt .getBody ())
9494 rRewriteStatementBlock (csb , fedVars );
9595 }
9696 else if (sb instanceof IfStatementBlock ) {
9797 IfStatementBlock isb = (IfStatementBlock ) sb ;
9898 IfStatement istmt = (IfStatement )isb .getStatement (0 );
99- rRewriteHop (isb .getPredicateHops (), new HashMap <>(), Collections . emptyMap ());
99+ rRewriteHop (isb .getPredicateHops (), new HashMap <>(), new HashMap <>(), sb . getDMLProg ());
100100 for (StatementBlock csb : istmt .getIfBody ())
101101 rRewriteStatementBlock (csb , fedVars );
102102 for (StatementBlock csb : istmt .getElseBody ())
@@ -105,9 +105,9 @@ else if (sb instanceof IfStatementBlock) {
105105 else if (sb instanceof ForStatementBlock ) { //incl parfor
106106 ForStatementBlock fsb = (ForStatementBlock ) sb ;
107107 ForStatement fstmt = (ForStatement )fsb .getStatement (0 );
108- rRewriteHop (fsb .getFromHops (), new HashMap <>(), Collections . emptyMap ());
109- rRewriteHop (fsb .getToHops (), new HashMap <>(), Collections . emptyMap ());
110- rRewriteHop (fsb .getIncrementHops (), new HashMap <>(), Collections . emptyMap ());
108+ rRewriteHop (fsb .getFromHops (), new HashMap <>(), new HashMap <>(), sb . getDMLProg ());
109+ rRewriteHop (fsb .getToHops (), new HashMap <>(), new HashMap <>(), sb . getDMLProg ());
110+ rRewriteHop (fsb .getIncrementHops (), new HashMap <>(), new HashMap <>(), sb . getDMLProg ());
111111 for (StatementBlock csb : fstmt .getBody ())
112112 rRewriteStatementBlock (csb , fedVars );
113113 }
@@ -117,9 +117,7 @@ else if (sb instanceof ForStatementBlock) { //incl parfor
117117 Map <Long , FType > fedHops = new HashMap <>();
118118 if ( sb .getHops () != null )
119119 for ( Hop c : sb .getHops () )
120- rRewriteHop (c , fedHops , fedVars );
121-
122- //TODO handle function calls
120+ rRewriteHop (c , fedHops , fedVars , sb .getDMLProg ());
123121
124122 //propagate federated outputs across DAGs
125123 if ( sb .getHops () != null )
@@ -129,19 +127,31 @@ else if (sb instanceof ForStatementBlock) { //incl parfor
129127 }
130128 }
131129
132- private void rRewriteHop (Hop hop , Map <Long , FType > memo , Map <String , FType > fedVars ) {
133- if ( memo .containsKey (hop .getHopID ()) )
130+ private void rRewriteHop (Hop hop , Map <Long , FType > memo , Map <String , FType > fedVars , DMLProgram program ) {
131+ if ( hop == null || memo .containsKey (hop .getHopID ()) )
134132 return ; //already processed
135133
136134 //process children first
137135 for ( Hop c : hop .getInput () )
138- rRewriteHop (c , memo , fedVars );
136+ rRewriteHop (c , memo , fedVars , program );
139137
140138 //handle specific operators (except transient writes)
141- if ( HopRewriteUtils .isData (hop , OpOpData .FEDERATED ) )
139+ if (hop instanceof FunctionOp ) {
140+ String funcName = ((FunctionOp ) hop ).getFunctionName ();
141+ String funcNamespace = ((FunctionOp ) hop ).getFunctionNamespace ();
142+ FunctionStatementBlock sbFuncBlock = program .getFunctionDictionary (funcNamespace ).getFunction (funcName );
143+ FunctionStatement funcStatement = (FunctionStatement ) sbFuncBlock .getStatement (0 );
144+
145+ Map <String , FType > funcFedVars = createFunctionFedVarTable ((FunctionOp ) hop , memo );
146+ rRewriteStatementBlock (sbFuncBlock , funcFedVars );
147+ mapFunctionOutputs ((FunctionOp ) hop , funcStatement , funcFedVars , fedVars );
148+ }
149+ else if ( HopRewriteUtils .isData (hop , OpOpData .FEDERATED ) )
142150 memo .put (hop .getHopID (), deriveFType ((DataOp )hop ));
143151 else if ( HopRewriteUtils .isData (hop , OpOpData .TRANSIENTREAD ) )
144152 memo .put (hop .getHopID (), fedVars .get (hop .getName ()));
153+ else if ( HopRewriteUtils .isData (hop , OpOpData .TRANSIENTWRITE ) )
154+ fedVars .put (hop .getName (), memo .get (hop .getHopID ()));
145155 else if ( allowsFederated (hop , memo ) ) {
146156 hop .setForcedExecType (ExecType .FED );
147157 memo .put (hop .getHopID (), getFederatedOut (hop , memo ));
@@ -151,4 +161,21 @@ else if( allowsFederated(hop, memo) ) {
151161 else // memoization as processed, but not federated
152162 memo .put (hop .getHopID (), null );
153163 }
164+
165+ static private Map <String , FType > createFunctionFedVarTable (FunctionOp hop , Map <Long , FType > memo ) {
166+ Map <String , Hop > funcParamMap = FederatedPlannerUtils .getParamMap (hop );
167+ Map <String , FType > funcFedVars = new HashMap <>();
168+ funcParamMap .forEach ((key , value ) -> {
169+ funcFedVars .put (key , memo .get (value .getHopID ()));
170+ });
171+ return funcFedVars ;
172+ }
173+
174+ private void mapFunctionOutputs (FunctionOp sbHop , FunctionStatement funcStatement ,
175+ Map <String , FType > funcFedVars , Map <String , FType > callFedVars ) {
176+ for (int i = 0 ; i < sbHop .getOutputVariableNames ().length ; ++i ) {
177+ FType outputFType = funcFedVars .get (funcStatement .getOutputParams ().get (i ).getName ());
178+ callFedVars .put (sbHop .getOutputVariableNames ()[i ], outputFType );
179+ }
180+ }
154181}
0 commit comments