Skip to content

Commit b5b6f37

Browse files
committed
[SYSTEMDS-3765] Fix time displacement through function hoisting
This patch fixes issues with time() functions which are used to measure execution time of parts of a program. When these functions were used in expressions (e.g., print string concatenation) the normal DAG compilation might move them before the operation that was actually measured. Similar to DML function calls, we now hoist these time functions out of expressions.
1 parent 34a6571 commit b5b6f37

File tree

3 files changed

+24
-7
lines changed

3 files changed

+24
-7
lines changed

src/main/java/org/apache/sysds/parser/StatementBlock.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,13 @@ else if( expr instanceof BuiltinFunctionExpression ) {
595595
Expression[] clexpr = lexpr.getAllExpr();
596596
for( int i=0; i<clexpr.length; i++ )
597597
clexpr[i] = rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp, prog);
598+
if( !root && lexpr.getOpCode()==Builtins.TIME ) { //core time hoisting
599+
String varname = StatementBlockRewriteRule.createCutVarName(true);
600+
DataIdentifier di = new DataIdentifier(varname);
601+
di.setDataType(lexpr.getDataType());
602+
di.setValueType(lexpr.getValueType());
603+
tmp.add(new AssignmentStatement(di, lexpr, di));
604+
}
598605
}
599606
else if( expr instanceof ParameterizedBuiltinFunctionExpression ) {
600607
ParameterizedBuiltinFunctionExpression lexpr = (ParameterizedBuiltinFunctionExpression) expr;
@@ -612,7 +619,7 @@ else if( expr instanceof FunctionCallIdentifier ) {
612619
FunctionCallIdentifier fexpr = (FunctionCallIdentifier) expr;
613620
for( ParameterExpression pexpr : fexpr.getParamExprs() )
614621
pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp, prog));
615-
if( !root ) { //core hoisting
622+
if( !root ) { //core fcall hoisting
616623
String varname = StatementBlockRewriteRule.createCutVarName(true);
617624
DataIdentifier di = new DataIdentifier(varname);
618625
di.setDataType(fexpr.getDataType());

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.sysds.test.functions.rewrite;
2121

22+
import org.junit.Assert;
2223
import org.junit.Test;
2324
import org.apache.sysds.common.Types.ExecMode;
2425
import org.apache.sysds.common.Types.ExecType;
@@ -43,9 +44,14 @@ public void setUp() {
4344
}
4445

4546
@Test
46-
public void testTimeHoisting() {
47+
public void testTimeHoistingCP() {
4748
test(TEST_NAME1, ExecType.CP);
4849
}
50+
51+
@Test
52+
public void testTimeHoistingSpark() {
53+
test(TEST_NAME1, ExecType.SPARK);
54+
}
4955

5056
private void test(String testname, ExecType et)
5157
{
@@ -58,11 +64,15 @@ private void test(String testname, ExecType et)
5864

5965
String HOME = SCRIPT_DIR + TEST_DIR;
6066
fullDMLScriptName = HOME + testname + ".dml";
61-
programArgs = new String[] { "-explain", "-args",
67+
programArgs = new String[] {"-args",
6268
String.valueOf(rows), String.valueOf(cols) };
63-
64-
//FIXME need to hoist time() out of expression similar to function calls
65-
runTest(true, false, null, -1);
69+
70+
//test that time is not executed before 1k-by-1k rand
71+
setOutputBuffering(true);
72+
String out = runTest(true, false, null, -1).toString();
73+
double time = Double.parseDouble(out.split(";")[1]);
74+
System.out.println("Time = "+time+"s");
75+
Assert.assertTrue(time>0.001);
6676
}
6777
finally {
6878
resetExecMode(platformOld);

src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@
2222
t1 = time();
2323
X = rand(rows=$1, cols=$2);
2424

25-
print("time = "+(time()-t1)/1e9+"s"+" "+sum(X));
25+
print(";"+(time()-t1)/1e9+";"+" "+sum(X));
2626

0 commit comments

Comments
 (0)