Skip to content

Commit 55075f8

Browse files
committed
[SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases, part 2
1 parent 6652d43 commit 55075f8

File tree

4 files changed

+18
-36
lines changed

4 files changed

+18
-36
lines changed

src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(Str
126126
CPOperand in2 = new CPOperand(parts[2]);
127127
int pos = 3;
128128
boolean metaReturn = true;
129-
System.out.println(Arrays.toString(parts));
130129
if( parts.length == 7 ) //no need for meta data
131130
metaReturn = new CPOperand(parts[pos++]).getLiteral().getBooleanValue();
132131
outputs.add(new CPOperand(parts[pos], Types.ValueType.FP64, Types.DataType.MATRIX));

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.apache.sysds.test.TestConfiguration;
2929
import org.apache.sysds.test.TestUtils;
3030
import org.junit.Assert;
31-
import org.junit.Ignore;
3231
import org.junit.Test;
3332

3433
public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@@ -60,7 +59,7 @@ public void testWeightedUnaryMMExpNoRewrite(){
6059

6160
@Test
6261
public void testWeightedUnaryMMExpRewrite(){
63-
testRewriteSimplifyWeightedUnaryMM(1, true); //pattern: W * exp(U%*%t(V))
62+
testRewriteSimplifyWeightedUnaryMM(1, true); //pattern: W * exp(U%*%t(V))
6463
}
6564

6665
@Test
@@ -70,7 +69,7 @@ public void testWeightedUnaryMMAbsNoRewrite(){
7069

7170
@Test
7271
public void testWeightedUnaryMMAbsRewrite(){
73-
testRewriteSimplifyWeightedUnaryMM(2, true); //pattern: W * abs(U%*%t(V))
72+
testRewriteSimplifyWeightedUnaryMM(2, true); //pattern: W * abs(U%*%t(V))
7473
}
7574

7675
@Test
@@ -80,7 +79,7 @@ public void testWeightedUnaryMMSinNoRewrite(){
8079

8180
@Test
8281
public void testWeightedUnaryMMSinRewrite(){
83-
testRewriteSimplifyWeightedUnaryMM(3, true); //pattern: W * sin(U%*%t(V))
82+
testRewriteSimplifyWeightedUnaryMM(3, true); //pattern: W * sin(U%*%t(V))
8483
}
8584

8685
/**
@@ -95,7 +94,7 @@ public void testWeightedUnaryMMScalarRightNoRewrite(){
9594

9695
@Test
9796
public void testWeightedUnaryMMScalarRightRewrite(){
98-
testRewriteSimplifyWeightedUnaryMM(4, true); //pattern: (W*(U%*%t(V)))*2
97+
testRewriteSimplifyWeightedUnaryMM(4, true); //pattern: (W*(U%*%t(V)))*2
9998
}
10099

101100
@Test
@@ -105,7 +104,7 @@ public void testWeightedUnaryMMScalarLeftNoRewrite(){
105104

106105
@Test
107106
public void testWeightedUnaryMMScalarLeftRewrite(){
108-
testRewriteSimplifyWeightedUnaryMM(5, true); //pattern: 2*(W*(U%*%t(V)))
107+
testRewriteSimplifyWeightedUnaryMM(5, true); //pattern: 2*(W*(U%*%t(V)))
109108
}
110109

111110
@Test
@@ -114,9 +113,8 @@ public void testWeightedUnaryMMMultLeftNoRewrite(){
114113
}
115114

116115
@Test
117-
@Ignore //FIXME non-applied rewrite
118116
public void testWeightedUnaryMMMultLeftRewrite(){
119-
testRewriteSimplifyWeightedUnaryMM(8, true); //pattern: W * (c * (U%*%t(V)))
117+
testRewriteSimplifyWeightedUnaryMM(8, true); //pattern: W * (2 * (U%*%t(V)))
120118
}
121119

122120
@Test
@@ -125,9 +123,8 @@ public void testWeightedUnaryMMMulRightNoRewrite(){
125123
}
126124

127125
@Test
128-
@Ignore //FIXME non-applied rewrite
129126
public void testWeightedUnaryMMMultRightRewrite(){
130-
testRewriteSimplifyWeightedUnaryMM(12, true); //pattern: W * ((U%*%t(V)) * c)
127+
testRewriteSimplifyWeightedUnaryMM(12, true); //pattern: W * ((U%*%t(V)) * 2)
131128
}
132129

133130

src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,35 +33,22 @@ U = as.matrix(readMM(paste(args[1], "U.mtx", sep="")))
3333
V = as.matrix(readMM(paste(args[1], "V.mtx", sep="")))
3434
W = as.matrix(readMM(paste(args[1], "W.mtx", sep="")))
3535
type = as.integer(args[2])
36-
c = 4.0
3736

3837
# Perform operations
39-
if(type == 1 || type == 14){
38+
if(type == 1){
4039
R = W * exp(U%*%t(V))
41-
} else if(type == 2 || type == 15){
40+
} else if(type == 2){
4241
R = W * abs(U%*%t(V))
43-
} else if(type == 3 || type == 16){
42+
} else if(type == 3){
4443
R = W * sin(U%*%t(V))
45-
} else if(type == 4 || type == 17){
44+
} else if(type == 4){
4645
R = (W*(U%*%t(V)))*2
47-
} else if(type == 5 || type == 18){
46+
} else if(type == 5){
4847
R = 2*(W*(U%*%t(V)))
49-
} else if(type == 6 || type == 19){
50-
R = W * (c + U%*%t(V))
51-
} else if(type == 7 || type == 20){
52-
R = W * (c - U%*%t(V))
53-
} else if(type == 8 || type == 21){
54-
R = W * (c * (U%*%t(V)))
55-
} else if(type == 9 || type == 22){
56-
R = W * (c / (U%*%t(V)))
57-
} else if(type == 10 || type == 23){
58-
R = W * (U%*%t(V) + c)
59-
} else if(type == 11 || type == 24){
60-
R = W * (U%*%t(V) - c)
61-
} else if(type == 12 || type == 25){
62-
R = W * ((U%*%t(V)) * c)
63-
} else if(type == 13 || type == 26){
64-
R = W * ((U%*%t(V)) / c)
48+
} else if(type == 8){
49+
R = W * (2 * (U%*%t(V)))
50+
} else if(type == 12){
51+
R = W * ((U%*%t(V)) * 2)
6552
}
6653

6754
#Write result matrix R

src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ U = read($1)
2424
V = read($2)
2525
W = read($3)
2626
type = $4
27-
c = 4.0
2827

2928
# Perform operations
3029
if(type == 1){
@@ -43,10 +42,10 @@ else if(type == 5){
4342
R = 2*(W*(U%*%t(V)))
4443
}
4544
else if(type == 8){
46-
R = W * (c * (U%*%t(V)))
45+
R = W * (2 * (U%*%t(V)))
4746
}
4847
else if(type == 12){
49-
R = W * ((U%*%t(V)) * c)
48+
R = W * ((U%*%t(V)) * 2)
5049
}
5150

5251
# Write the result matrix R

0 commit comments

Comments
 (0)