Skip to content

Commit 7de3657

Browse files
committed
[SYSTEMDS-3785] Fix rewrite test for simplify bushy binary ops
This patch resolves a remaining FIXME after improved rewrite code coverage by fixing the expressions and other rewrite configs so the test actually triggers the existing rewrite.
1 parent 63b99e5 commit 7de3657

File tree

4 files changed

+25
-21
lines changed

4 files changed

+25
-21
lines changed

src/main/java/org/apache/sysds/hops/OptimizerUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ public enum MemoryManager {
195195
* all sum-product related rewrites.
196196
*/
197197
public static boolean ALLOW_SUM_PRODUCT_REWRITES = true;
198+
public static boolean ALLOW_SUM_PRODUCT_REWRITES2 = true;
198199

199200
/**
200201
* Enables additional mmchain optimizations. in the future, this might be merged with

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
126126
}
127127
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
128128
_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
129-
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
129+
if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
130+
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()); //dependency: cse
130131
}
131132
if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
132133
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,8 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int
855855
}
856856

857857
/**
858-
* (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
859-
* (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
858+
* t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%(X*Y)*(Z%*%v)
859+
* t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
860860
*
861861
* Note: Restriction ba() at leaf and root instead of data at leaf to not reorganize too
862862
* eagerly, which would loose additional rewrite potential. This rewrite has two goals

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.sysds.test.AutomatedTestBase;
2626
import org.apache.sysds.test.TestConfiguration;
2727
import org.apache.sysds.test.TestUtils;
28+
import org.junit.Assert;
2829
import org.junit.Test;
2930

3031
import java.util.HashMap;
@@ -37,7 +38,7 @@ public class RewriteSimplifyBushyBinaryOperationTest extends AutomatedTestBase {
3738
TEST_DIR + RewriteSimplifyBushyBinaryOperationTest.class.getSimpleName() + "/";
3839

3940
private static final int rows = 500;
40-
private static final int cols = 500;
41+
private static final int cols = 100;
4142
private static final double eps = Math.pow(10, -10);
4243

4344
@Override
@@ -46,28 +47,28 @@ public void setUp() {
4647
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
4748
}
4849

50+
//pattern: t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%((X*Y)*(Z%*%v))
4951
@Test
5052
public void testBushyBinaryOperationMultNoRewrite() {
5153
testSimplifyBushyBinaryOperation(1, false);
5254
}
5355

5456
@Test
55-
public void testBushyBinaryOperationMultRewrite() { //pattern: (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
57+
public void testBushyBinaryOperationMultRewrite() {
5658
testSimplifyBushyBinaryOperation(1, true);
5759
}
5860

61+
//pattern: t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
5962
@Test
6063
public void testBushyBinaryOperationAddNoRewrite() {
6164
testSimplifyBushyBinaryOperation(2, false);
6265
}
6366

6467
@Test
65-
public void testBushyBinaryOperationAddtRewrite() { //pattern: (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
68+
public void testBushyBinaryOperationAddtRewrite() {
6669
testSimplifyBushyBinaryOperation(2, true);
6770
}
6871

69-
70-
7172
private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {
7273
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
7374
try {
@@ -76,19 +77,21 @@ private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {
7677

7778
String HOME = SCRIPT_DIR + TEST_DIR;
7879
fullDMLScriptName = HOME + TEST_NAME + ".dml";
79-
programArgs = new String[] {"-stats", "-args", input("X"), input("Y"), input("Z"), input("v"), String.valueOf(ID), output("R")};
80+
programArgs = new String[] {"-stats", "-explain", "-args",
81+
input("X"), input("Y"), input("Z"), input("v"), String.valueOf(ID), output("R")};
8082
fullRScriptName = HOME + TEST_NAME + ".R";
8183
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());
8284

8385
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
84-
//OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
85-
//OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
86-
86+
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = false; //disable nary mult
87+
OptimizerUtils.ALLOW_OPERATOR_FUSION = false; //disable emult reordering
88+
//TODO improved phase ordering
89+
8790
//create matrices
88-
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.60d, 3);
89-
double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.60d, 5);
91+
double[][] X = getRandomMatrix(rows, 1, -1, 1, 0.60d, 3);
92+
double[][] Y = getRandomMatrix(rows, 1, -1, 1, 0.60d, 5);
9093
double[][] Z = getRandomMatrix(rows, cols, -1, 1, 0.60d, 6);
91-
double[][] v = getRandomMatrix(rows, cols, -1, 1, 0.60d, 8);
94+
double[][] v = getRandomMatrix(cols, 1, -1, 1, 0.60d, 8);
9295
writeInputMatrixWithMTD("X", X, true);
9396
writeInputMatrixWithMTD("Y", Y, true);
9497
writeInputMatrixWithMTD("Z", Z, true);
@@ -101,15 +104,14 @@ private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {
101104
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
102105
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
103106
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
104-
105-
/**
106-
* The rewrite in RewriteAlgebraicSimplificationStatic is not entered. Hence, we fail
107-
* the assertions for this rewrite so that we can revisit this issue later.
108-
*/
109-
//FIXME
107+
108+
if( ID == 1 && rewrites ) //check mmchain, enabled by bushy join
109+
Assert.assertTrue(heavyHittersContainsString("mmchain"));
110110
}
111111
finally {
112112
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
113+
OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
114+
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = true;
113115
Recompiler.reinitRecompiler();
114116
}
115117
}

0 commit comments

Comments
 (0)