Skip to content

Commit 9b6a96d

Browse files
committed
[SYSTEMDS-3783] Fix wsigmoid rewrite test setup
The recent addition of various rewrite tests for code coverage left a FIXME on the wsigmoid test which gave incorrect results for all variants without transpose. After double checking, it turns out the test setup was wrong in the assumptions when the rewrite should apply (missing transpose) and how the shapes of involved matrices look like.
1 parent 29b3c61 commit 9b6a96d

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020
package org.apache.sysds.test.functions.rewrite;
2121

22+
import java.util.HashMap;
23+
2224
import org.apache.sysds.hops.OptimizerUtils;
25+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
2326
import org.apache.sysds.test.AutomatedTestBase;
2427
import org.apache.sysds.test.TestConfiguration;
2528
import org.apache.sysds.test.TestUtils;
@@ -32,9 +35,8 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest extends AutomatedTestBas
3235
private static final String TEST_CLASS_DIR =
3336
TEST_DIR + RewriteSimplifyWeightedSigmoidMMChainsTest.class.getSimpleName() + "/";
3437

35-
private static final int rows = 100;
38+
private static final int rows = 150;
3639
private static final int cols = 100;
37-
//private static final double eps = Math.pow(10, -10);
3840

3941
@Override
4042
public void setUp() {
@@ -125,8 +127,9 @@ private void testRewriteSimplifyWeightedSigmoidMMChains(int ID, boolean rewrites
125127
OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
126128

127129
//create matrices
128-
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.80d, 3);
129-
double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.70d, 4);
130+
int rank = 50;
131+
double[][] X = getRandomMatrix(cols, rank, -1, 1, 0.80d, 3);
132+
double[][] Y = getRandomMatrix(rows, rank, -1, 1, 0.70d, 4);
130133
double[][] W = getRandomMatrix(rows, cols, -1, 1, 0.60d, 5);
131134
writeInputMatrixWithMTD("X", X, true);
132135
writeInputMatrixWithMTD("Y", Y, true);
@@ -136,10 +139,9 @@ private void testRewriteSimplifyWeightedSigmoidMMChains(int ID, boolean rewrites
136139
runRScript(true);
137140

138141
//compare matrices
139-
// FIXME
140-
// HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
141-
// HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
142-
// compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
142+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
143+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
144+
TestUtils.compareMatrices(dmlfile, rfile, 1e-8, "Stat-DML", "Stat-R");
143145

144146
if(rewrites)
145147
Assert.assertTrue(heavyHittersContainsString("wsigmoid"));

src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Y = read($2)
2525
W = read($3)
2626
type = $4
2727

28+
if( type > 4 )
29+
X = t(X);
2830

2931
# Perform operations
3032
if(type == 1){

0 commit comments

Comments
 (0)