Skip to content

Commit 78b23cf

Browse files
committed
[SYSTEMDS-3841] Additional multiLogReg tests to check mmchain rewrite
As it turns out, there was no bug causing mmchain not being applied for the builtin function multiLogReg, but we only apply this rewrite for binary classification not multi-class classification (where only codegen is capable of doing so). We now added additional tests for both binary/multi-class and the check for correctly applied mmchain.
1 parent 51b53c3 commit 78b23cf

File tree

1 file changed

+50
-19
lines changed

1 file changed

+50
-19
lines changed

src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMultiLogRegTest.java

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919

2020
package org.apache.sysds.test.functions.builtin.part2;
2121

22+
import org.junit.Assert;
2223
import org.junit.Test;
23-
import org.apache.sysds.api.DMLScript;
2424
import org.apache.sysds.common.Types;
2525
import org.apache.sysds.common.Types.ExecType;
26-
import org.apache.sysds.hops.OptimizerUtils;
2726
import org.apache.sysds.test.AutomatedTestBase;
2827
import org.apache.sysds.test.TestConfiguration;
2928
import org.apache.sysds.test.TestUtils;
@@ -50,58 +49,90 @@ public void setUp() {
5049

5150
@Test
5251
public void testMultiLogRegInterceptCP0() {
53-
runMultiLogeRegTest( 0, tol, 1.0, maxIter, maxInnerIter, ExecType.CP);
52+
runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter, ExecType.CP);
5453
}
5554
@Test
5655
public void testMultiLogRegInterceptCP1() {
57-
runMultiLogeRegTest( 1, tol, 1.0, maxIter, maxInnerIter, ExecType.CP);
56+
runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter, ExecType.CP);
5857
}
5958
@Test
6059
public void testMultiLogRegInterceptCP2() {
61-
runMultiLogeRegTest( 2, tol, 1.0, maxIter, maxInnerIter, ExecType.CP);
60+
runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter, ExecType.CP);
6261
}
62+
@Test
63+
public void testMultiLogRegBinInterceptCP0() {
64+
runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter, 2, ExecType.CP);
65+
}
66+
@Test
67+
public void testMultiLogRegBinInterceptCP1() {
68+
runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter, 2, ExecType.CP);
69+
}
70+
@Test
71+
public void testMultiLogRegBinInterceptCP2() {
72+
runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter, 2, ExecType.CP);
73+
}
74+
6375
@Test
6476
public void testMultiLogRegInterceptSpark0() {
65-
runMultiLogeRegTest( 0, tol, 1.0, maxIter, maxInnerIter, ExecType.SPARK);
77+
runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter, ExecType.SPARK);
6678
}
6779
@Test
6880
public void testMultiLogRegInterceptSpark1() {
69-
runMultiLogeRegTest( 1, tol, 1.0, maxIter, maxInnerIter, ExecType.SPARK);
81+
runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter, ExecType.SPARK);
7082
}
7183
@Test
7284
public void testMultiLogRegInterceptSpark2() {
7385
runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter, ExecType.SPARK);
7486
}
87+
88+
@Test
89+
public void testMultiLogRegBinInterceptSpark0() {
90+
runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter, 2, ExecType.SPARK);
91+
}
92+
@Test
93+
public void testMultiLogRegBinInterceptSpark1() {
94+
runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter, 2, ExecType.SPARK);
95+
}
96+
@Test
97+
public void testMultiLogRegBinInterceptSpark2() {
98+
runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter, 2, ExecType.SPARK);
99+
}
75100

76-
private void runMultiLogeRegTest( int inc, double tol, double reg, int maxOut, int maxIn, ExecType instType) {
101+
private void runMultiLogeRegTest(int inc, double tol, double reg, int maxOut, int maxIn, ExecType instType) {
102+
runMultiLogeRegTest(inc, tol, reg, maxOut, maxIn, 6, instType);
103+
}
104+
105+
private void runMultiLogeRegTest(int inc, double tol, double reg,
106+
int maxOut, int maxIn, int numClasses, ExecType instType)
107+
{
77108
Types.ExecMode platformOld = setExecMode(instType);
78109

79-
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
80-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
81-
82110
try {
83111
loadTestConfiguration(getTestConfiguration(TEST_NAME));
84112

85113
String HOME = SCRIPT_DIR + TEST_DIR;
86114
fullDMLScriptName = HOME + TEST_NAME + ".dml";
87115

88-
programArgs = new String[]{"-nvargs", "X=" + input("X"), "Y=" + input("Y"), "output=" + output("betas"),
89-
"inc=" + String.valueOf(inc).toUpperCase(), "tol=" + tol, "reg=" + reg, "maxOut=" + maxOut, "maxIn="+maxIn, "verbose=FALSE"};
116+
programArgs = new String[]{"-stats","-nvargs",
117+
"X=" + input("X"), "Y=" + input("Y"), "output=" + output("betas"),
118+
"inc=" + String.valueOf(inc).toUpperCase(), "tol=" + tol,
119+
"reg=" + reg, "maxOut=" + maxOut, "maxIn="+maxIn, "verbose=FALSE"};
90120

91121
double[][] X = getRandomMatrix(rows, colsX, 0, 1, sparse, -1);
92-
double[][] Y = getRandomMatrix(rows, 1, 0, 5, 1, -1);
122+
double[][] Y = getRandomMatrix(rows, 1, 0, numClasses-1, 1, -1);
93123
Y = TestUtils.round(Y);
94124

95125
writeInputMatrixWithMTD("X", X, true);
96126
writeInputMatrixWithMTD("Y", Y, true);
97127
runTest(true, false, null, -1);
128+
129+
if(numClasses == 2) {
130+
String opcode = instType==ExecType.SPARK ? "sp_mapmmchain" : "mmchain";
131+
Assert.assertTrue(heavyHittersContainsString(opcode));
132+
}
98133
}
99134
finally {
100-
rtplatform = platformOld;
101-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
102-
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
103-
OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
104-
OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
135+
resetExecMode(platformOld);
105136
}
106137
}
107138
}

0 commit comments

Comments
 (0)