Skip to content

Commit c86aa0a

Browse files
ReneEnjilianmboehm7
authored andcommitted
[SYSTEMDS-3774] Improved test coverage of simplification rewrites
Closes #2109.
1 parent 2ce1910 commit c86aa0a

File tree

70 files changed

+4426
-37
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+4426
-37
lines changed

src/main/java/org/apache/sysds/hops/OptimizerUtils.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,12 @@ public enum MemoryManager {
195195
* all sum-product related rewrites.
196196
*/
197197
public static boolean ALLOW_SUM_PRODUCT_REWRITES = true;
198+
199+
/**
200+
* Enables additional mmchain optimizations. in the future, this might be merged with
201+
* ALLOW_SUM_PRODUCT_REWRITES.
202+
*/
203+
public static boolean ALLOW_ADVANCED_MMCHAIN_REWRITES = false;
198204

199205
/**
200206
* Enables a specific hop dag rewrite that splits hop dags after csv persistent reads with

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
127127
_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
128128
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
129129
}
130+
if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
131+
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse
132+
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationSparse() ); //dependency: cse
133+
}
130134
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) {
131135
_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
132136
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,7 +1635,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
16351635
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
16361636
tX = HopRewriteUtils.createTranspose(tX);
16371637
}
1638-
else
1638+
else
16391639
tX = tX.getInput().get(0);
16401640

16411641
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
@@ -1664,7 +1664,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
16641664
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
16651665
tX = HopRewriteUtils.createTranspose(tX);
16661666
}
1667-
else
1667+
else
16681668
tX = tX.getInput().get(0);
16691669

16701670
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
@@ -1690,7 +1690,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
16901690
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
16911691
tX = HopRewriteUtils.createTranspose(tX);
16921692
}
1693-
else
1693+
else
16941694
tX = tX.getInput().get(0);
16951695

16961696
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
@@ -1722,7 +1722,7 @@ private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
17221722
if( !HopRewriteUtils.isTransposeOperation(tX) ) {
17231723
tX = HopRewriteUtils.createTranspose(tX);
17241724
}
1725-
else
1725+
else
17261726
tX = tX.getInput().get(0);
17271727

17281728
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
@@ -2157,7 +2157,7 @@ private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
21572157

21582158
if( !HopRewriteUtils.isTransposeOperation(V) )
21592159
V = HopRewriteUtils.createTranspose(V);
2160-
else
2160+
else
21612161
V = V.getInput().get(0);
21622162

21632163
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
@@ -2251,7 +2251,7 @@ else if( left.getDataType()==DataType.SCALAR && left instanceof LiteralOp
22512251

22522252
if( !HopRewriteUtils.isTransposeOperation(V) )
22532253
V = HopRewriteUtils.createTranspose(V);
2254-
else
2254+
else
22552255
V = V.getInput().get(0);
22562256

22572257
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.apache.sysds.hops.OptimizerUtils;
23+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
24+
import org.apache.sysds.test.AutomatedTestBase;
25+
import org.apache.sysds.test.TestConfiguration;
26+
import org.apache.sysds.test.TestUtils;
27+
import org.junit.Assert;
28+
import org.junit.Test;
29+
30+
import java.util.HashMap;
31+
32+
public class RewriteFuseBinarySubDAGToUnaryOperationTest extends AutomatedTestBase {
33+
34+
private static final String TEST_NAME = "RewriteFuseBinarySubDAGToUnaryOperation";
35+
private static final String TEST_DIR = "functions/rewrite/";
36+
private static final String TEST_CLASS_DIR =
37+
TEST_DIR + RewriteFuseBinarySubDAGToUnaryOperationTest.class.getSimpleName() + "/";
38+
39+
private static final int rows = 300;
40+
private static final int cols = 200;
41+
private static final double eps = Math.pow(10, -10);
42+
43+
@Override
44+
public void setUp() {
45+
TestUtils.clearAssertionInformation();
46+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
47+
}
48+
49+
@Test
50+
public void testSampleProportionLeftNoRewrite(){
51+
testSimplifyDistributiveBinaryOperation(1, false);
52+
}
53+
54+
@Test
55+
public void testSampleProportionLeftRewrite(){
56+
testSimplifyDistributiveBinaryOperation(1, true); //pattern: (1-X)*X -> sprop(X)
57+
}
58+
59+
@Test
60+
public void testSampleProportionRightNoRewrite(){
61+
testSimplifyDistributiveBinaryOperation(2, false);
62+
}
63+
64+
@Test
65+
public void testSampleProportionRightRewrite(){
66+
testSimplifyDistributiveBinaryOperation(2, true); //pattern: X*(1-X) -> sprop(X)
67+
}
68+
69+
@Test
70+
public void testFuseBinarySubDAGToUnarySigmoidNoRewrite(){
71+
testSimplifyDistributiveBinaryOperation(3, false);
72+
}
73+
74+
@Test
75+
public void testFuseBinarySubDAGToUnarySigmoidRewrite(){
76+
testSimplifyDistributiveBinaryOperation(3, true); //pattern: 1/(1+exp(-X)) -> sigmoid(X)
77+
}
78+
79+
80+
private void testSimplifyDistributiveBinaryOperation(int ID, boolean rewrites) {
81+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
82+
try {
83+
TestConfiguration config = getTestConfiguration(TEST_NAME);
84+
loadTestConfiguration(config);
85+
86+
String HOME = SCRIPT_DIR + TEST_DIR;
87+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
88+
programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")};
89+
fullRScriptName = HOME + TEST_NAME + ".R";
90+
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());
91+
92+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
93+
94+
//create dense matrix so that rewrites are possible
95+
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.80d, 3);
96+
writeInputMatrixWithMTD("X", X, true);
97+
98+
runTest(true, false, null, -1);
99+
runRScript(true);
100+
101+
//compare matrices
102+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
103+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
104+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
105+
106+
if (rewrites)
107+
Assert.assertTrue(heavyHittersContainsString("sprop") || heavyHittersContainsString("sigmoid"));
108+
else
109+
Assert.assertFalse(heavyHittersContainsString("sprop") || heavyHittersContainsString("sigmoid"));
110+
111+
112+
}
113+
finally {
114+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
115+
}
116+
}
117+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.apache.sysds.hops.OptimizerUtils;
23+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
24+
import org.apache.sysds.test.AutomatedTestBase;
25+
import org.apache.sysds.test.TestConfiguration;
26+
import org.apache.sysds.test.TestUtils;
27+
import org.junit.Assert;
28+
import org.junit.Test;
29+
30+
import java.util.HashMap;
31+
32+
public class RewriteFuseLeftIndexingChainToAppendTest extends AutomatedTestBase {
33+
private static final String TEST_NAME = "RewriteFuseLeftIndexingChainToAppend";
34+
private static final String TEST_DIR = "functions/rewrite/";
35+
private static final String TEST_CLASS_DIR =
36+
TEST_DIR + RewriteFuseLeftIndexingChainToAppendTest.class.getSimpleName() + "/";
37+
38+
private static final int rows = 300;
39+
private static final int cols = 1;
40+
private static final double eps = Math.pow(10, -10);
41+
42+
@Override
43+
public void setUp() {
44+
TestUtils.clearAssertionInformation();
45+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
46+
}
47+
48+
@Test
49+
public void testFuseLeftIndexingChainColumnNoRewrite() {
50+
testRewriteFuseLeftIndexingChainToAppend(1, false);
51+
}
52+
53+
@Test
54+
public void testFuseLeftIndexingChainColumnRewrite() {
55+
testRewriteFuseLeftIndexingChainToAppend(1, true);
56+
}
57+
58+
@Test
59+
public void testFuseLeftIndexingChainRowNoRewrite() {
60+
testRewriteFuseLeftIndexingChainToAppend(2, false);
61+
}
62+
63+
@Test
64+
public void testFuseLeftIndexingChainRowRewrite() {
65+
testRewriteFuseLeftIndexingChainToAppend(2, true);
66+
}
67+
68+
private void testRewriteFuseLeftIndexingChainToAppend(int ID, boolean rewrites) {
69+
boolean oldFlag1 = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
70+
boolean oldFlag2 = OptimizerUtils.ALLOW_OPERATOR_FUSION;
71+
try {
72+
TestConfiguration config = getTestConfiguration(TEST_NAME);
73+
loadTestConfiguration(config);
74+
75+
String HOME = SCRIPT_DIR + TEST_DIR;
76+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
77+
programArgs = new String[] {"-stats", "-args", input("A"), input("B"), String.valueOf(ID), output("R")};
78+
fullRScriptName = HOME + TEST_NAME + ".R";
79+
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());
80+
81+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
82+
OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
83+
84+
//create matrices
85+
double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.80d, 3);
86+
double[][] B = getRandomMatrix(rows, cols, -1, 1, 0.80d, 5);
87+
writeInputMatrixWithMTD("A", A, true);
88+
writeInputMatrixWithMTD("B", B, true);
89+
90+
runTest(true, false, null, -1);
91+
runRScript(true);
92+
93+
//compare matrices
94+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
95+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
96+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
97+
98+
if(rewrites)
99+
Assert.assertTrue(heavyHittersContainsString("append"));
100+
else
101+
Assert.assertTrue(heavyHittersContainsString("leftIndex"));
102+
103+
}
104+
finally {
105+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag1;
106+
OptimizerUtils.ALLOW_OPERATOR_FUSION = oldFlag2;
107+
}
108+
109+
}
110+
}

0 commit comments

Comments
 (0)