2525import org .apache .sysds .test .AutomatedTestBase ;
2626import org .apache .sysds .test .TestConfiguration ;
2727import org .apache .sysds .test .TestUtils ;
28+ import org .junit .Assert ;
2829import org .junit .Test ;
2930
3031import java .util .HashMap ;
@@ -37,7 +38,7 @@ public class RewriteSimplifyBushyBinaryOperationTest extends AutomatedTestBase {
3738 TEST_DIR + RewriteSimplifyBushyBinaryOperationTest .class .getSimpleName () + "/" ;
3839
3940 private static final int rows = 500 ;
40- private static final int cols = 500 ;
41+ private static final int cols = 100 ;
4142 private static final double eps = Math .pow (10 , -10 );
4243
4344 @ Override
@@ -46,28 +47,28 @@ public void setUp() {
4647 addTestConfiguration (TEST_NAME , new TestConfiguration (TEST_CLASS_DIR , TEST_NAME , new String [] {"R" }));
4748 }
4849
50+ //pattern: t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%((X*Y)*(Z%*%v))
4951 @ Test
5052 public void testBushyBinaryOperationMultNoRewrite () {
5153 testSimplifyBushyBinaryOperation (1 , false );
5254 }
5355
5456 @ Test
55- public void testBushyBinaryOperationMultRewrite () { //pattern: (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
57+ public void testBushyBinaryOperationMultRewrite () {
5658 testSimplifyBushyBinaryOperation (1 , true );
5759 }
5860
61+ //pattern: t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
5962 @ Test
6063 public void testBushyBinaryOperationAddNoRewrite () {
6164 testSimplifyBushyBinaryOperation (2 , false );
6265 }
6366
6467 @ Test
65- public void testBushyBinaryOperationAddtRewrite () { //pattern: (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
68+ public void testBushyBinaryOperationAddtRewrite () {
6669 testSimplifyBushyBinaryOperation (2 , true );
6770 }
6871
69-
70-
7172 private void testSimplifyBushyBinaryOperation (int ID , boolean rewrites ) {
7273 boolean oldFlag = OptimizerUtils .ALLOW_ALGEBRAIC_SIMPLIFICATION ;
7374 try {
@@ -76,19 +77,21 @@ private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {
7677
7778 String HOME = SCRIPT_DIR + TEST_DIR ;
7879 fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
79- programArgs = new String [] {"-stats" , "-args" , input ("X" ), input ("Y" ), input ("Z" ), input ("v" ), String .valueOf (ID ), output ("R" )};
80+ programArgs = new String [] {"-stats" , "-explain" , "-args" ,
81+ input ("X" ), input ("Y" ), input ("Z" ), input ("v" ), String .valueOf (ID ), output ("R" )};
8082 fullRScriptName = HOME + TEST_NAME + ".R" ;
8183 rCmd = getRCmd (inputDir (), String .valueOf (ID ), expectedDir ());
8284
8385 OptimizerUtils .ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites ;
84- //OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
85- //OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
86-
86+ OptimizerUtils .ALLOW_SUM_PRODUCT_REWRITES2 = false ; //disable nary mult
87+ OptimizerUtils .ALLOW_OPERATOR_FUSION = false ; //disable emult reordering
88+ //TODO improved phase ordering
89+
8790 //create matrices
88- double [][] X = getRandomMatrix (rows , cols , -1 , 1 , 0.60d , 3 );
89- double [][] Y = getRandomMatrix (rows , cols , -1 , 1 , 0.60d , 5 );
91+ double [][] X = getRandomMatrix (rows , 1 , -1 , 1 , 0.60d , 3 );
92+ double [][] Y = getRandomMatrix (rows , 1 , -1 , 1 , 0.60d , 5 );
9093 double [][] Z = getRandomMatrix (rows , cols , -1 , 1 , 0.60d , 6 );
91- double [][] v = getRandomMatrix (rows , cols , -1 , 1 , 0.60d , 8 );
94+ double [][] v = getRandomMatrix (cols , 1 , -1 , 1 , 0.60d , 8 );
9295 writeInputMatrixWithMTD ("X" , X , true );
9396 writeInputMatrixWithMTD ("Y" , Y , true );
9497 writeInputMatrixWithMTD ("Z" , Z , true );
@@ -101,15 +104,14 @@ private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {
101104 HashMap <MatrixValue .CellIndex , Double > dmlfile = readDMLMatrixFromOutputDir ("R" );
102105 HashMap <MatrixValue .CellIndex , Double > rfile = readRMatrixFromExpectedDir ("R" );
103106 TestUtils .compareMatrices (dmlfile , rfile , eps , "Stat-DML" , "Stat-R" );
104-
105- /**
106- * The rewrite in RewriteAlgebraicSimplificationStatic is not entered. Hence, we fail
107- * the assertions for this rewrite so that we can revisit this issue later.
108- */
109- //FIXME
107+
108+ if ( ID == 1 && rewrites ) //check mmchain, enabled by bushy join
109+ Assert .assertTrue (heavyHittersContainsString ("mmchain" ));
110110 }
111111 finally {
112112 OptimizerUtils .ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag ;
113+ OptimizerUtils .ALLOW_OPERATOR_FUSION = true ;
114+ OptimizerUtils .ALLOW_SUM_PRODUCT_REWRITES2 = true ;
113115 Recompiler .reinitRecompiler ();
114116 }
115117 }
0 commit comments