Skip to content

Commit b716d02

Browse files
committed
[SYSTEMDS-3798] Fix rewrite loop vectorization and code coverage
The existing single loop vectorization rewrite only checked for correct results not if the rewrite was indeed applied. This issue has been found while improving the code coverage of rewrites. When the handling of predicates (with special transient write) was changed in the past, these rewrites were not properly updated. This patch fixes all rewrites and also add tests for all cases. Furthermore, this patch also removes two other unused hop rewrites.
1 parent cee72fc commit b716d02

12 files changed

+357
-330
lines changed

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
7575
if( staticRewrites )
7676
{
7777
//add static HOP DAG rewrite rules
78-
_dagRuleSet.add( new RewriteTransientWriteParentHandling() );
7978
_dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize
8079
_dagRuleSet.add( new RewriteBlockSizeAndReblock() );
8180
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java

Lines changed: 109 additions & 122 deletions
Large diffs are not rendered by default.

src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java

Lines changed: 0 additions & 110 deletions
This file was deleted.

src/main/java/org/apache/sysds/hops/rewrite/RewriteTransientWriteParentHandling.java

Lines changed: 0 additions & 76 deletions
This file was deleted.

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,23 @@
2121

2222
import java.util.HashMap;
2323

24+
import org.junit.Assert;
2425
import org.junit.Test;
2526
import org.apache.sysds.hops.OptimizerUtils;
2627
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
2728
import org.apache.sysds.test.AutomatedTestBase;
2829
import org.apache.sysds.test.TestConfiguration;
2930
import org.apache.sysds.test.TestUtils;
31+
import org.apache.sysds.utils.Statistics;
3032

31-
/**
32-
* Regression test for loop vectorization rewrite
33-
* for(i in 1:n) s = s + as.scalar(A[i,1]) -> s = s + sum(A[1:n,1])
34-
*
35-
*/
3633
public class RewriteLoopVectorization extends AutomatedTestBase
3734
{
3835
private static final String TEST_NAME1 = "RewriteLoopVectorizationSum"; //amendable
3936
private static final String TEST_NAME2 = "RewriteLoopVectorizationSum2"; //not amendable
40-
37+
private static final String TEST_NAME3 = "RewriteLoopVectorizationBinary"; //amendable
38+
private static final String TEST_NAME4 = "RewriteLoopVectorizationUnary"; //amendable
39+
private static final String TEST_NAME5 = "RewriteLoopVectorizationIndexedCopy"; //amendable
40+
4141
private static final String TEST_DIR = "functions/rewrite/";
4242
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteLoopVectorization.class.getSimpleName() + "/";
4343

@@ -46,6 +46,9 @@ public void setUp() {
4646
TestUtils.clearAssertionInformation();
4747
addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
4848
addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
49+
addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
50+
addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
51+
addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
4952
}
5053

5154
@Test
@@ -68,14 +71,39 @@ public void testLoopVectorizationSum2Rewrite() {
6871
testRewriteLoopVectorizationSum( TEST_NAME2, true );
6972
}
7073

71-
/**
72-
*
73-
* @param testname
74-
* @param rewrites
75-
*/
74+
@Test
75+
public void testLoopVectorizationBinaryNoRewrite() {
76+
testRewriteLoopVectorizationSum( TEST_NAME3, false );
77+
}
78+
79+
@Test
80+
public void testLoopVectorizationBinaryRewrite() {
81+
testRewriteLoopVectorizationSum( TEST_NAME3, true );
82+
}
83+
84+
@Test
85+
public void testLoopVectorizationUnaryNoRewrite() {
86+
testRewriteLoopVectorizationSum( TEST_NAME4, false );
87+
}
88+
89+
@Test
90+
public void testLoopVectorizationUnaryRewrite() {
91+
testRewriteLoopVectorizationSum( TEST_NAME4, true );
92+
}
93+
94+
@Test
95+
public void testLoopVectorizationIndexedCopyNoRewrite() {
96+
testRewriteLoopVectorizationSum( TEST_NAME5, false );
97+
}
98+
99+
@Test
100+
public void testLoopVectorizationIndexedCopyRewrite() {
101+
testRewriteLoopVectorizationSum( TEST_NAME5, true );
102+
}
103+
76104
private void testRewriteLoopVectorizationSum( String testname, boolean rewrites )
77105
{
78-
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
106+
boolean oldFlag = OptimizerUtils.ALLOW_AUTO_VECTORIZATION;
79107

80108
try
81109
{
@@ -84,12 +112,12 @@ private void testRewriteLoopVectorizationSum( String testname, boolean rewrites
84112

85113
String HOME = SCRIPT_DIR + TEST_DIR;
86114
fullDMLScriptName = HOME + testname + ".dml";
87-
programArgs = new String[]{ "-stats","-args", output("Scalar") };
115+
programArgs = new String[]{ "-stats", "-args", output("Scalar") };
88116

89117
fullRScriptName = HOME + testname + ".R";
90-
rCmd = getRCmd(inputDir(), expectedDir());
118+
rCmd = getRCmd(inputDir(), expectedDir());
91119

92-
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
120+
OptimizerUtils.ALLOW_AUTO_VECTORIZATION = rewrites;
93121

94122
runTest(true, false, null, -1);
95123
runRScript(true);
@@ -98,9 +126,11 @@ private void testRewriteLoopVectorizationSum( String testname, boolean rewrites
98126
HashMap<CellIndex, Double> dmlfile = readDMLScalarFromOutputDir("Scalar");
99127
HashMap<CellIndex, Double> rfile = readRScalarFromExpectedDir("Scalar");
100128
TestUtils.compareScalars(dmlfile.toString(), rfile.toString());
129+
if( !testname.equals(TEST_NAME2) && rewrites )
130+
Assert.assertTrue(Statistics.getCPHeavyHitterCount("rightIndex") <= 2);
101131
}
102132
finally {
103-
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
104-
}
105-
}
106-
}
133+
OptimizerUtils.ALLOW_AUTO_VECTORIZATION = oldFlag;
134+
}
135+
}
136+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
#-------------------------------------------------------------
19+
20+
args<-commandArgs(TRUE)
21+
22+
A = matrix(7.0, 10, 10)
23+
B = matrix(3.0, 10, 10)
24+
n = nrow(A)
25+
s = 0.0
26+
27+
X = A
28+
for(i in 2:(n-1)) {
29+
X[i,2] = A[i,1] + B[i,3]
30+
}
31+
s = sum(X)
32+
33+
write(s, paste(args[2], "Scalar",sep=""))
34+
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
A = matrix(7.0, 10, 10)
23+
B = matrix(3.0, 10, 10)
24+
n = nrow(A)
25+
s = 0.0
26+
27+
X = A
28+
for(i in 2:n-1) {
29+
X[i,2] = A[i,1] + B[i,3]
30+
}
31+
s = sum(X)
32+
33+
write(s, $1)

0 commit comments

Comments
 (0)