Skip to content

Commit 12d8cd7

Browse files
committed
[SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases
1 parent 9efc4be commit 12d8cd7

File tree

3 files changed

+32
-221
lines changed

3 files changed

+32
-221
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
8585
private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV};
8686

8787
//valid unary and binary operators for wumm
88-
private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, OpOp1.LOG, OpOp1.SQRT, OpOp1.SIGMOID, OpOp1.SPROP};
89-
private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.POW};
88+
private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{
89+
OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, OpOp1.LOG,
90+
OpOp1.SQRT, OpOp1.SIN, OpOp1.COS, OpOp1.SIGMOID, OpOp1.SPROP};
91+
private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new OpOp2[]{
92+
OpOp2.MULT, OpOp2.POW};
9093

9194
@Override
9295
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java

Lines changed: 19 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@
1919

2020
package org.apache.sysds.test.functions.rewrite;
2121

22+
import java.util.HashMap;
23+
2224
import org.apache.sysds.hops.OptimizerUtils;
25+
import org.apache.sysds.hops.recompile.Recompiler;
26+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
2327
import org.apache.sysds.test.AutomatedTestBase;
2428
import org.apache.sysds.test.TestConfiguration;
2529
import org.apache.sysds.test.TestUtils;
30+
import org.junit.Assert;
31+
import org.junit.Ignore;
2632
import org.junit.Test;
2733

2834
public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@@ -31,9 +37,8 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
3137
private static final String TEST_CLASS_DIR =
3238
TEST_DIR + RewriteSimplifyWeightedUnaryMMTest.class.getSimpleName() + "/";
3339

34-
private static final int rows = 100;
35-
private static final int cols = 100;
36-
//private static final double eps = Math.pow(10, -7);
40+
private static final int rows = 1123; //larger than blocksize needed
41+
private static final int cols = 1245;
3742

3843
@Override
3944
public void setUp() {
@@ -103,166 +108,28 @@ public void testWeightedUnaryMMScalarLeftRewrite(){
103108
testRewriteSimplifyWeightedUnaryMM(5, true); //pattern: 2*(W*(U%*%t(V)))
104109
}
105110

106-
/**
107-
* These tests cover the case for the third pattern
108-
* W * sop(U%*%t(V), c) or W * sop(U%*%t(V), c), where
109-
* sop stands for scalar operation (+, -, *, /) and c represents
110-
* some constant scalar.
111-
* */
112-
113-
@Test
114-
public void testWeightedUnaryMMAddLeftNoRewrite(){
115-
testRewriteSimplifyWeightedUnaryMM(6, false);
116-
}
117-
118-
@Test
119-
public void testWeightedUnaryMMAddLeftRewrite(){
120-
testRewriteSimplifyWeightedUnaryMM(6, true); //pattern: W * (c + U%*%t(V))
121-
}
122-
123-
@Test
124-
public void testWeightedUnaryMMMinusLeftNoRewrite(){
125-
testRewriteSimplifyWeightedUnaryMM(7, false);
126-
}
127-
128-
@Test
129-
public void testWeightedUnaryMMMinusLeftRewrite(){
130-
testRewriteSimplifyWeightedUnaryMM(7, true); //pattern: W * (c - U%*%t(V))
131-
}
132-
133111
@Test
134112
public void testWeightedUnaryMMMultLeftNoRewrite(){
135113
testRewriteSimplifyWeightedUnaryMM(8, false);
136114
}
137115

138116
@Test
117+
@Ignore //FIXME non-applied rewrite
139118
public void testWeightedUnaryMMMultLeftRewrite(){
140119
testRewriteSimplifyWeightedUnaryMM(8, true); //pattern: W * (c * (U%*%t(V)))
141120
}
142121

143-
@Test
144-
public void testWeightedUnaryMMDivLeftNoRewrite(){
145-
testRewriteSimplifyWeightedUnaryMM(9, false);
146-
}
147-
148-
@Test
149-
public void testWeightedUnaryMMDivLeftRewrite(){
150-
testRewriteSimplifyWeightedUnaryMM(9, true); //pattern: W * (c / (U%*%t(V)))
151-
}
152-
153-
// Same pattern but scalar from right instead of left
154-
155-
@Test
156-
public void testWeightedUnaryMMAddRightNoRewrite(){
157-
testRewriteSimplifyWeightedUnaryMM(10, false);
158-
}
159-
160-
@Test
161-
public void testWeightedUnaryMMAddRightRewrite(){
162-
testRewriteSimplifyWeightedUnaryMM(10, true); //pattern: W * (U%*%t(V) + c)
163-
}
164-
165-
@Test
166-
public void testWeightedUnaryMMMinusRightNoRewrite(){
167-
testRewriteSimplifyWeightedUnaryMM(11, false);
168-
}
169-
170-
@Test
171-
public void testWeightedUnaryMMMinusRightRewrite(){
172-
testRewriteSimplifyWeightedUnaryMM(11, true); //pattern: W * (U%*%t(V) - c)
173-
}
174-
175122
@Test
176123
public void testWeightedUnaryMMMulRightNoRewrite(){
177124
testRewriteSimplifyWeightedUnaryMM(12, false);
178125
}
179126

180127
@Test
128+
@Ignore //FIXME non-applied rewrite
181129
public void testWeightedUnaryMMMultRightRewrite(){
182130
testRewriteSimplifyWeightedUnaryMM(12, true); //pattern: W * ((U%*%t(V)) * c)
183131
}
184132

185-
@Test
186-
public void testWeightedUnaryMMDivRightNoRewrite(){
187-
testRewriteSimplifyWeightedUnaryMM(13, false);
188-
}
189-
190-
@Test
191-
public void testWeightedUnaryMMDivRightRewrite(){
192-
testRewriteSimplifyWeightedUnaryMM(13, true); //pattern: W * ((U%*%t(V)) / c)
193-
}
194-
195-
/**
196-
* Here, we omit the transpose in the dml script. The rewrite should catch the missing transpose
197-
* and replace V with t(V).
198-
**/
199-
200-
@Test
201-
public void testWeightedUnaryMMExpNoTranspose(){
202-
testRewriteSimplifyWeightedUnaryMM(14, true); //pattern: W * exp(U%*%V)
203-
}
204-
205-
@Test
206-
public void testWeightedUnaryMMAbsNoTranspose(){
207-
testRewriteSimplifyWeightedUnaryMM(15, true); //pattern: W * abs(U%*%V)
208-
}
209-
210-
@Test
211-
public void testWeightedUnaryMMSinNoTranspose(){
212-
testRewriteSimplifyWeightedUnaryMM(16, true); //pattern: W * sin(U%*%V)
213-
}
214-
215-
@Test
216-
public void testWeightedUnaryMMScalarRightNoTranspose(){
217-
testRewriteSimplifyWeightedUnaryMM(17, true); //pattern: (W*(U%*%V))*2
218-
}
219-
220-
@Test
221-
public void testWeightedUnaryMMScalarLeftNoTranspose(){
222-
testRewriteSimplifyWeightedUnaryMM(18, true); //pattern: 2*(W*(U%*%V))
223-
}
224-
225-
@Test
226-
public void testWeightedUnaryMMAddLeftNoTranspose(){
227-
testRewriteSimplifyWeightedUnaryMM(19, true); //pattern: W * (c + U%*%V)
228-
}
229-
230-
@Test
231-
public void testWeightedUnaryMMMinusLeftNoTranspose(){
232-
testRewriteSimplifyWeightedUnaryMM(20, true); //pattern: W * (c - U%*%V)
233-
}
234-
235-
@Test
236-
public void testWeightedUnaryMMMultLeftNoTranspose(){
237-
testRewriteSimplifyWeightedUnaryMM(21, true); //pattern: W * (c * (U%*%V))
238-
}
239-
240-
@Test
241-
public void testWeightedUnaryMMDivLeftNoTranspose(){
242-
testRewriteSimplifyWeightedUnaryMM(22, true); //pattern: W * (c / (U%*%V))
243-
}
244-
245-
@Test
246-
public void testWeightedUnaryMMAddRightNoTranspose(){
247-
testRewriteSimplifyWeightedUnaryMM(23, true); //pattern: W * (U%*%V + c)
248-
}
249-
250-
@Test
251-
public void testWeightedUnaryMMMinusRightNoTranspose(){
252-
testRewriteSimplifyWeightedUnaryMM(24, true); //pattern: W * (U%*%V - c)
253-
}
254-
255-
@Test
256-
public void testWeightedUnaryMMMultRightNoTranspose(){
257-
testRewriteSimplifyWeightedUnaryMM(25, true); //pattern: W * ((U%*%V) * c)
258-
}
259-
260-
@Test
261-
public void testWeightedUnaryMMDivRightNoTranspose(){
262-
testRewriteSimplifyWeightedUnaryMM(26, true); //pattern: W * ((U%*%V) / c)
263-
}
264-
265-
266133

267134
private void testRewriteSimplifyWeightedUnaryMM(int ID, boolean rewrites) {
268135
boolean oldFlag1 = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -280,11 +147,13 @@ private void testRewriteSimplifyWeightedUnaryMM(int ID, boolean rewrites) {
280147

281148
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
282149
OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
150+
Recompiler.reinitRecompiler();
283151

284152
//create matrices
285-
double[][] U = getRandomMatrix(rows, cols, -1, 1, 0.80d, 3);
286-
double[][] V = getRandomMatrix(rows, cols, -1, 1, 0.70d, 4);
287-
double[][] W = getRandomMatrix(rows, cols, -1, 1, 0.60d, 5);
153+
int rank = 50;
154+
double[][] U = getRandomMatrix(rows, rank, -1, 1, 0.80d, 3);
155+
double[][] V = getRandomMatrix(cols, rank, -1, 1, 0.70d, 4);
156+
double[][] W = getRandomMatrix(rows, cols, -1, 1, 0.01d, 5);
288157
writeInputMatrixWithMTD("U", U, true);
289158
writeInputMatrixWithMTD("V", V, true);
290159
writeInputMatrixWithMTD("W", W, true);
@@ -293,15 +162,10 @@ private void testRewriteSimplifyWeightedUnaryMM(int ID, boolean rewrites) {
293162
runRScript(true);
294163

295164
//compare matrices
296-
// FIXME
297-
// HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
298-
// HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
299-
// TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
300-
// if(rewrites)
301-
// Assert.assertTrue(heavyHittersContainsString("wumm"));
302-
// else
303-
// Assert.assertFalse(heavyHittersContainsString("wumm"));
304-
165+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
166+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
167+
TestUtils.compareMatrices(dmlfile, rfile, 1e-8, "Stat-DML", "Stat-R");
168+
Assert.assertTrue(heavyHittersContainsString("wumm")==rewrites);
305169
}
306170
finally {
307171
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag1;

src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml

Lines changed: 8 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -28,83 +28,27 @@ c = 4.0
2828

2929
# Perform operations
3030
if(type == 1){
31-
R = W * exp(U%*%t(V))
31+
R = W * exp(U%*%t(V))
3232
}
3333
else if(type == 2){
34-
R = W * abs(U%*%t(V))
34+
R = W * abs(U%*%t(V))
3535
}
3636
else if(type == 3){
37-
R = W * sin(U%*%t(V))
37+
R = W * sin(U%*%t(V))
3838
}
3939
else if(type == 4){
40-
R = (W*(U%*%t(V)))*2
40+
R = (W*(U%*%t(V)))*2
4141
}
4242
else if(type == 5){
43-
R = 2*(W*(U%*%t(V)))
44-
}
45-
else if(type == 6){
46-
R = W * (c + U%*%t(V))
47-
}
48-
else if(type == 7){
49-
R = W * (c - U%*%t(V))
43+
R = 2*(W*(U%*%t(V)))
5044
}
5145
else if(type == 8){
52-
R = W * (c * (U%*%t(V)))
53-
}
54-
else if(type == 9){
55-
R = W * (c / (U%*%t(V)))
56-
}
57-
else if(type == 10){
58-
R = W * (U%*%t(V) + c)
59-
}
60-
else if(type == 11){
61-
R = W * (U%*%t(V) - c)
46+
R = W * (c * (U%*%t(V)))
6247
}
6348
else if(type == 12){
64-
R = W * ((U%*%t(V)) * c)
65-
}
66-
else if(type == 13){
67-
R = W * ((U%*%t(V)) / c)
68-
}
69-
else if(type == 14){
70-
R = W * exp(U%*%V)
71-
}
72-
else if(type == 15){
73-
R = W * abs(U%*%V)
74-
}
75-
else if(type == 16){
76-
R = W * sin(U%*%V)
77-
}
78-
else if(type == 17){
79-
R = (W*(U%*%V))*2
80-
}
81-
else if(type == 18){
82-
R = 2*(W*(U%*%V))
83-
}
84-
else if(type == 19){
85-
R = W * (c + U%*%V)
86-
}
87-
else if(type == 20){
88-
R = W * (c - U%*%V)
89-
}
90-
else if(type == 21){
91-
R = W * (c * (U%*%V))
92-
}
93-
else if(type == 22){
94-
R = W * (c / (U%*%V))
95-
}
96-
else if(type == 23){
97-
R = W * (U%*%V + c)
98-
}
99-
else if(type == 24){
100-
R = W * (U%*%V - c)
101-
}
102-
else if(type == 25){
103-
R = W * ((U%*%V) * c)
104-
}
105-
else if(type == 26){
106-
R = W * ((U%*%V) / c)
49+
R = W * ((U%*%t(V)) * c)
10750
}
10851

10952
# Write the result matrix R
11053
write(R, $5)
54+

0 commit comments

Comments
 (0)