Skip to content

Commit 70bca7f

Browse files
committed
[MINOR] Cleanup append tests for proper exec type handling
1 parent 0ec4dcd commit 70bca7f

File tree

4 files changed

+12
-46
lines changed

4 files changed

+12
-46
lines changed

src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import org.junit.Assert;
2525
import org.junit.Test;
26-
import org.apache.sysds.api.DMLScript;
2726
import org.apache.sysds.common.Types.ExecMode;
2827
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
2928
import org.apache.sysds.test.AutomatedTestBase;
@@ -104,15 +103,10 @@ public void testAppendChainMatrixSparseCP() {
104103
public void commonAppendTest(ExecMode platform, int rows, int cols1, int cols2, int cols3, boolean sparse)
105104
{
106105
TestConfiguration config = getAndLoadTestConfiguration(TEST_NAME);
107-
ExecMode prevPlfm=rtplatform;
108-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
106+
ExecMode prevPlfm=setExecMode(platform);
109107

110108
try
111109
{
112-
rtplatform = platform;
113-
if( rtplatform == ExecMode.SPARK )
114-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
115-
116110
config.addVariable("rows", rows);
117111
config.addVariable("cols", cols1);
118112

@@ -150,8 +144,7 @@ public void commonAppendTest(ExecMode platform, int rows, int cols1, int cols2,
150144
}
151145
}
152146
finally {
153-
rtplatform = prevPlfm;
154-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
147+
resetExecMode(prevPlfm);
155148
}
156149
}
157150
}

src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import org.junit.Assert;
2626
import org.junit.Test;
27-
import org.apache.sysds.api.DMLScript;
2827
import org.apache.sysds.common.Types.ExecMode;
2928
import org.apache.sysds.hops.BinaryOp;
3029
import org.apache.sysds.hops.OptimizerUtils;
@@ -146,18 +145,14 @@ public void testAppendOutBlock2SparseSP() {
146145
public void commonAppendTest(ExecMode platform, int rows, int cols1, int cols2, boolean sparse, AppendMethod forcedAppendMethod)
147146
{
148147
TestConfiguration config = getAndLoadTestConfiguration(TEST_NAME);
149-
ExecMode prevPlfm=rtplatform;
148+
ExecMode prevPlfm=setExecMode(platform);
150149
double sparsity = (sparse) ? sparsity2 : sparsity1;
151-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
152150

153151
try {
154152
if(forcedAppendMethod != null) {
155153
BinaryOp.FORCED_APPEND_METHOD = forcedAppendMethod;
156154
}
157-
rtplatform = platform;
158-
if( rtplatform == ExecMode.SPARK )
159-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
160-
155+
161156
config.addVariable("rows", rows);
162157
config.addVariable("cols", cols1);
163158

@@ -192,9 +187,7 @@ public void commonAppendTest(ExecMode platform, int rows, int cols1, int cols2,
192187
}
193188
}
194189
finally {
195-
//reset execution platform
196-
rtplatform = prevPlfm;
197-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
190+
resetExecMode(prevPlfm);
198191
BinaryOp.FORCED_APPEND_METHOD = null;
199192
}
200193
}

src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import org.junit.Assert;
2626
import org.junit.Test;
27-
import org.apache.sysds.api.DMLScript;
2827
import org.apache.sysds.common.Types.ExecMode;
2928
import org.apache.sysds.hops.OptimizerUtils;
3029
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -76,14 +75,8 @@ public void testAppendOutBlockCP() {
7675
public void commonAppendTest(ExecMode platform, int rows, int cols)
7776
{
7877
TestConfiguration config = getAndLoadTestConfiguration(TEST_NAME);
79-
80-
ExecMode prevPlfm=rtplatform;
78+
ExecMode prevPlfm = setExecMode(platform);
8179

82-
rtplatform = platform;
83-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
84-
if( rtplatform == ExecMode.SPARK )
85-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
86-
8780
try {
8881
config.addVariable("rows", rows);
8982
config.addVariable("cols", cols);
@@ -121,8 +114,7 @@ public void commonAppendTest(ExecMode platform, int rows, int cols)
121114
}
122115
}
123116
finally {
124-
rtplatform = prevPlfm;
125-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
117+
resetExecMode(prevPlfm);
126118
}
127119
}
128120
}

src/test/java/org/apache/sysds/test/functions/append/RBindCBindMatrixTest.java

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.apache.sysds.common.Opcodes;
2525
import org.junit.Assert;
2626
import org.junit.Test;
27-
import org.apache.sysds.api.DMLScript;
2827
import org.apache.sysds.common.Types.ExecMode;
2928
import org.apache.sysds.common.Types.ExecType;
3029
import org.apache.sysds.runtime.instructions.Instruction;
@@ -103,23 +102,15 @@ public void testCBindSparseSP() {
103102

104103
public void runRBindTest(String testname, boolean sparse, ExecType et)
105104
{
106-
ExecMode platformOld = rtplatform;
107-
switch( et ){
108-
case SPARK: rtplatform = ExecMode.SPARK; break;
109-
default: rtplatform = ExecMode.HYBRID; break;
110-
}
111-
112-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
113-
if( rtplatform == ExecMode.SPARK )
114-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
115-
105+
ExecMode platformOld = setExecMode(et);
106+
116107
String TEST_NAME = testname;
117108
TestConfiguration config = getTestConfiguration(TEST_NAME);
118109
loadTestConfiguration(config);
119110
double sparsity = (sparse) ? sparsity2 : sparsity1;
120111

121112
try
122-
{
113+
{
123114
String RI_HOME = SCRIPT_DIR + TEST_DIR;
124115
fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
125116
//stats required for opcode checks
@@ -151,11 +142,8 @@ public void runRBindTest(String testname, boolean sparse, ExecType et)
151142
Assert.assertTrue("Rewrite not applied", !Statistics.getCPHeavyHitterOpCodes().contains(opcode) );
152143
}
153144
}
154-
finally
155-
{
156-
//reset execution platform
157-
rtplatform = platformOld;
158-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
145+
finally {
146+
resetExecMode(platformOld);
159147
}
160148
}
161149
}

0 commit comments

Comments
 (0)