Skip to content

Commit 13611b0

Browse files
kev-innmboehm7
authored andcommitted
[SYSTEMDS-3018] Support function calls in FedAll and Heuristic planners
Closes #1666.
1 parent 8cb230d commit 13611b0

File tree

3 files changed

+40
-39
lines changed

3 files changed

+40
-39
lines changed

src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020
package org.apache.sysds.hops.fedplanner;
2121

22-
import java.util.Collections;
2322
import java.util.HashMap;
2423
import java.util.Map;
2524

2625
import org.apache.sysds.common.Types.ExecType;
2726
import org.apache.sysds.common.Types.OpOpData;
2827
import org.apache.sysds.hops.DataOp;
28+
import org.apache.sysds.hops.FunctionOp;
2929
import org.apache.sysds.hops.Hop;
3030
import org.apache.sysds.hops.fedplanner.FTypes.FType;
3131
import org.apache.sysds.hops.ipa.FunctionCallGraph;
@@ -89,14 +89,14 @@ private void rRewriteStatementBlock(StatementBlock sb, Map<String, FType> fedVar
8989
else if (sb instanceof WhileStatementBlock) {
9090
WhileStatementBlock wsb = (WhileStatementBlock) sb;
9191
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
92-
rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), Collections.emptyMap());
92+
rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
9393
for (StatementBlock csb : wstmt.getBody())
9494
rRewriteStatementBlock(csb, fedVars);
9595
}
9696
else if (sb instanceof IfStatementBlock) {
9797
IfStatementBlock isb = (IfStatementBlock) sb;
9898
IfStatement istmt = (IfStatement)isb.getStatement(0);
99-
rRewriteHop(isb.getPredicateHops(), new HashMap<>(), Collections.emptyMap());
99+
rRewriteHop(isb.getPredicateHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
100100
for (StatementBlock csb : istmt.getIfBody())
101101
rRewriteStatementBlock(csb, fedVars);
102102
for (StatementBlock csb : istmt.getElseBody())
@@ -105,9 +105,9 @@ else if (sb instanceof IfStatementBlock) {
105105
else if (sb instanceof ForStatementBlock) { //incl parfor
106106
ForStatementBlock fsb = (ForStatementBlock) sb;
107107
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
108-
rRewriteHop(fsb.getFromHops(), new HashMap<>(), Collections.emptyMap());
109-
rRewriteHop(fsb.getToHops(), new HashMap<>(), Collections.emptyMap());
110-
rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), Collections.emptyMap());
108+
rRewriteHop(fsb.getFromHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
109+
rRewriteHop(fsb.getToHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
110+
rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
111111
for (StatementBlock csb : fstmt.getBody())
112112
rRewriteStatementBlock(csb, fedVars);
113113
}
@@ -117,9 +117,7 @@ else if (sb instanceof ForStatementBlock) { //incl parfor
117117
Map<Long, FType> fedHops = new HashMap<>();
118118
if( sb.getHops() != null )
119119
for( Hop c : sb.getHops() )
120-
rRewriteHop(c, fedHops, fedVars);
121-
122-
//TODO handle function calls
120+
rRewriteHop(c, fedHops, fedVars, sb.getDMLProg());
123121

124122
//propagate federated outputs across DAGs
125123
if( sb.getHops() != null )
@@ -129,19 +127,31 @@ else if (sb instanceof ForStatementBlock) { //incl parfor
129127
}
130128
}
131129

132-
private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String, FType> fedVars) {
133-
if( memo.containsKey(hop.getHopID()) )
130+
private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String, FType> fedVars, DMLProgram program) {
131+
if( hop == null || memo.containsKey(hop.getHopID()) )
134132
return; //already processed
135133

136134
//process children first
137135
for( Hop c : hop.getInput() )
138-
rRewriteHop(c, memo, fedVars);
136+
rRewriteHop(c, memo, fedVars, program);
139137

140138
//handle specific operators (except transient writes)
141-
if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
139+
if(hop instanceof FunctionOp) {
140+
String funcName = ((FunctionOp) hop).getFunctionName();
141+
String funcNamespace = ((FunctionOp) hop).getFunctionNamespace();
142+
FunctionStatementBlock sbFuncBlock = program.getFunctionDictionary(funcNamespace).getFunction(funcName);
143+
FunctionStatement funcStatement = (FunctionStatement) sbFuncBlock.getStatement(0);
144+
145+
Map<String, FType> funcFedVars = createFunctionFedVarTable((FunctionOp) hop, memo);
146+
rRewriteStatementBlock(sbFuncBlock, funcFedVars);
147+
mapFunctionOutputs((FunctionOp) hop, funcStatement, funcFedVars, fedVars);
148+
}
149+
else if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
142150
memo.put(hop.getHopID(), deriveFType((DataOp)hop));
143151
else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) )
144152
memo.put(hop.getHopID(), fedVars.get(hop.getName()));
153+
else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE) )
154+
fedVars.put(hop.getName(), memo.get(hop.getHopID()));
145155
else if( allowsFederated(hop, memo) ) {
146156
hop.setForcedExecType(ExecType.FED);
147157
memo.put(hop.getHopID(), getFederatedOut(hop, memo));
@@ -151,4 +161,21 @@ else if( allowsFederated(hop, memo) ) {
151161
else // memoization as processed, but not federated
152162
memo.put(hop.getHopID(), null);
153163
}
164+
165+
static private Map<String, FType> createFunctionFedVarTable(FunctionOp hop, Map<Long, FType> memo) {
166+
Map<String, Hop> funcParamMap = FederatedPlannerUtils.getParamMap(hop);
167+
Map<String, FType> funcFedVars = new HashMap<>();
168+
funcParamMap.forEach((key, value) -> {
169+
funcFedVars.put(key, memo.get(value.getHopID()));
170+
});
171+
return funcFedVars;
172+
}
173+
174+
private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement funcStatement,
175+
Map<String, FType> funcFedVars, Map<String, FType> callFedVars) {
176+
for(int i = 0; i < sbHop.getOutputVariableNames().length; ++i) {
177+
FType outputFType = funcFedVars.get(funcStatement.getOutputParams().get(i).getName());
178+
callFedVars.put(sbHop.getOutputVariableNames()[i], outputFType);
179+
}
180+
}
154181
}

src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import org.apache.sysds.hops.FunctionOp;
2323
import org.apache.sysds.hops.Hop;
24-
import org.apache.sysds.parser.FunctionStatement;
2524
import org.apache.sysds.runtime.DMLRuntimeException;
2625
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
2726

@@ -76,26 +75,4 @@ public static Map<String,Hop> getParamMap(FunctionOp funcOp){
7675
}
7776
return paramMap;
7877
}
79-
80-
/**
81-
* Saves the HOPs (TWrite) of the function return values for
82-
* the variable name used when calling the function.
83-
*
84-
* Example:
85-
* <code>
86-
* f = function() return (matrix[double] model) {a = rand(1, 1);}
87-
* b = f();
88-
* </code>
89-
* This function saves the HOP writing to <code>a</code> for identifier <code>b</code>.
90-
*
91-
* @param sbHop The <code>FunctionOp</code> for the call
92-
* @param funcStatement The <code>FunctionStatement</code> of the called function
93-
* @param transientWrites map of transient writes
94-
*/
95-
public static void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement funcStatement, Map<String,Hop> transientWrites) {
96-
for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i) {
97-
Hop outputWrite = transientWrites.get(funcStatement.getOutputParams().get(i).getName());
98-
transientWrites.put(sbHop.getOutputVariableNames()[i], outputWrite);
99-
}
100-
}
10178
}

src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.apache.sysds.test.AutomatedTestBase;
2828
import org.apache.sysds.test.TestConfiguration;
2929
import org.apache.sysds.test.TestUtils;
30-
import org.junit.Ignore;
3130
import org.junit.Test;
3231

3332
import java.io.File;
@@ -72,7 +71,6 @@ public void runL2SVMHeuristicTest(){
7271
}
7372

7473
@Test
75-
@Ignore //TODO
7674
public void runL2SVMFunctionFOUTTest(){
7775
String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
7876
"fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
@@ -81,7 +79,6 @@ public void runL2SVMFunctionFOUTTest(){
8179
}
8280

8381
@Test
84-
@Ignore //TODO
8582
public void runL2SVMFunctionHeuristicTest(){
8683
String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"};
8784
setTestConf("SystemDS-config-heuristic.xml");

0 commit comments

Comments
 (0)