1919
2020package org .apache .sysds .test .functions .builtin .part2 ;
2121
22+ import org .junit .Assert ;
2223import org .junit .Test ;
23- import org .apache .sysds .api .DMLScript ;
2424import org .apache .sysds .common .Types ;
2525import org .apache .sysds .common .Types .ExecType ;
26- import org .apache .sysds .hops .OptimizerUtils ;
2726import org .apache .sysds .test .AutomatedTestBase ;
2827import org .apache .sysds .test .TestConfiguration ;
2928import org .apache .sysds .test .TestUtils ;
@@ -50,58 +49,90 @@ public void setUp() {
5049
5150 @ Test
5251 public void testMultiLogRegInterceptCP0 () {
53- runMultiLogeRegTest ( 0 , tol , 1.0 , maxIter , maxInnerIter , ExecType .CP );
52+ runMultiLogeRegTest (0 , tol , 1.0 , maxIter , maxInnerIter , ExecType .CP );
5453 }
5554 @ Test
5655 public void testMultiLogRegInterceptCP1 () {
57- runMultiLogeRegTest ( 1 , tol , 1.0 , maxIter , maxInnerIter , ExecType .CP );
56+ runMultiLogeRegTest (1 , tol , 1.0 , maxIter , maxInnerIter , ExecType .CP );
5857 }
5958 @ Test
6059 public void testMultiLogRegInterceptCP2 () {
61- runMultiLogeRegTest ( 2 , tol , 1.0 , maxIter , maxInnerIter , ExecType .CP );
60+ runMultiLogeRegTest (2 , tol , 1.0 , maxIter , maxInnerIter , ExecType .CP );
6261 }
62+ @ Test
63+ public void testMultiLogRegBinInterceptCP0 () {
64+ runMultiLogeRegTest (0 , tol , 1.0 , maxIter , maxInnerIter , 2 , ExecType .CP );
65+ }
66+ @ Test
67+ public void testMultiLogRegBinInterceptCP1 () {
68+ runMultiLogeRegTest (1 , tol , 1.0 , maxIter , maxInnerIter , 2 , ExecType .CP );
69+ }
70+ @ Test
71+ public void testMultiLogRegBinInterceptCP2 () {
72+ runMultiLogeRegTest (2 , tol , 1.0 , maxIter , maxInnerIter , 2 , ExecType .CP );
73+ }
74+
6375 @ Test
6476 public void testMultiLogRegInterceptSpark0 () {
65- runMultiLogeRegTest ( 0 , tol , 1.0 , maxIter , maxInnerIter , ExecType .SPARK );
77+ runMultiLogeRegTest (0 , tol , 1.0 , maxIter , maxInnerIter , ExecType .SPARK );
6678 }
6779 @ Test
6880 public void testMultiLogRegInterceptSpark1 () {
69- runMultiLogeRegTest ( 1 , tol , 1.0 , maxIter , maxInnerIter , ExecType .SPARK );
81+ runMultiLogeRegTest (1 , tol , 1.0 , maxIter , maxInnerIter , ExecType .SPARK );
7082 }
7183 @ Test
7284 public void testMultiLogRegInterceptSpark2 () {
7385 runMultiLogeRegTest (2 , tol , 1.0 , maxIter , maxInnerIter , ExecType .SPARK );
7486 }
87+
88+ @ Test
89+ public void testMultiLogRegBinInterceptSpark0 () {
90+ runMultiLogeRegTest (0 , tol , 1.0 , maxIter , maxInnerIter , 2 , ExecType .SPARK );
91+ }
92+ @ Test
93+ public void testMultiLogRegBinInterceptSpark1 () {
94+ runMultiLogeRegTest (1 , tol , 1.0 , maxIter , maxInnerIter , 2 , ExecType .SPARK );
95+ }
96+ @ Test
97+ public void testMultiLogRegBinInterceptSpark2 () {
98+ runMultiLogeRegTest (2 , tol , 1.0 , maxIter , maxInnerIter , 2 , ExecType .SPARK );
99+ }
75100
76- private void runMultiLogeRegTest ( int inc , double tol , double reg , int maxOut , int maxIn , ExecType instType ) {
101+ private void runMultiLogeRegTest (int inc , double tol , double reg , int maxOut , int maxIn , ExecType instType ) {
102+ runMultiLogeRegTest (inc , tol , reg , maxOut , maxIn , 6 , instType );
103+ }
104+
105+ private void runMultiLogeRegTest (int inc , double tol , double reg ,
106+ int maxOut , int maxIn , int numClasses , ExecType instType )
107+ {
77108 Types .ExecMode platformOld = setExecMode (instType );
78109
79- boolean oldFlag = OptimizerUtils .ALLOW_ALGEBRAIC_SIMPLIFICATION ;
80- boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
81-
82110 try {
83111 loadTestConfiguration (getTestConfiguration (TEST_NAME ));
84112
85113 String HOME = SCRIPT_DIR + TEST_DIR ;
86114 fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
87115
88- programArgs = new String []{"-nvargs" , "X=" + input ("X" ), "Y=" + input ("Y" ), "output=" + output ("betas" ),
89- "inc=" + String .valueOf (inc ).toUpperCase (), "tol=" + tol , "reg=" + reg , "maxOut=" + maxOut , "maxIn=" +maxIn , "verbose=FALSE" };
116+ programArgs = new String []{"-stats" ,"-nvargs" ,
117+ "X=" + input ("X" ), "Y=" + input ("Y" ), "output=" + output ("betas" ),
118+ "inc=" + String .valueOf (inc ).toUpperCase (), "tol=" + tol ,
119+ "reg=" + reg , "maxOut=" + maxOut , "maxIn=" +maxIn , "verbose=FALSE" };
90120
91121 double [][] X = getRandomMatrix (rows , colsX , 0 , 1 , sparse , -1 );
92- double [][] Y = getRandomMatrix (rows , 1 , 0 , 5 , 1 , -1 );
122+ double [][] Y = getRandomMatrix (rows , 1 , 0 , numClasses - 1 , 1 , -1 );
93123 Y = TestUtils .round (Y );
94124
95125 writeInputMatrixWithMTD ("X" , X , true );
96126 writeInputMatrixWithMTD ("Y" , Y , true );
97127 runTest (true , false , null , -1 );
128+
129+ if (numClasses == 2 ) {
130+ String opcode = instType ==ExecType .SPARK ? "sp_mapmmchain" : "mmchain" ;
131+ Assert .assertTrue (heavyHittersContainsString (opcode ));
132+ }
98133 }
99134 finally {
100- rtplatform = platformOld ;
101- DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
102- OptimizerUtils .ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag ;
103- OptimizerUtils .ALLOW_AUTO_VECTORIZATION = true ;
104- OptimizerUtils .ALLOW_OPERATOR_FUSION = true ;
135+ resetExecMode (platformOld );
105136 }
106137 }
107138}
0 commit comments