2222import java .util .HashMap ;
2323
2424import org .apache .sysds .common .Opcodes ;
25- import org .apache .sysds .utils .stats .InfrastructureAnalyzer ;
2625import org .junit .Assert ;
2726import org .junit .Test ;
28- import org .apache .sysds .api .DMLScript ;
2927import org .apache .sysds .common .Types .ExecMode ;
3028import org .apache .sysds .common .Types .ExecType ;
3129import org .apache .sysds .runtime .instructions .Instruction ;
@@ -44,18 +42,16 @@ public class FullReverseTest extends AutomatedTestBase
4442 private final static String TEST_DIR = "functions/reorg/" ;
4543 private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest .class .getSimpleName () + "/" ;
4644
47- private final static int rows1 = 2017 ;
48- private final static int cols1 = 1001 ;
45+ //single-threaded execution
46+ private final static int rows1 = 201 ;
47+ private final static int cols1 = 100 ;
48+ //multi-threaded / distributed execution
49+ private final static int rows2 = 2017 ;
50+ private final static int cols2 = 1001 ;
51+
4952 private final static double sparsity1 = 0.7 ;
5053 private final static double sparsity2 = 0.1 ;
5154
52- // Multi-threading test parameters
53- private final static int rows_mt = 5018 ; // Larger for multi-threading benefits
54- private final static int cols_mt = 1001 ; // Larger for multi-threading benefits
55- private final static int [] threadCounts = {1 , 2 , 4 , 8 };
56- // Set global parallelism for SystemDS to enable multi-threading
57- private final static int oldPar = InfrastructureAnalyzer .getLocalParallelism ();
58-
5955 @ Override
6056 public void setUp () {
6157 TestUtils .clearAssertionInformation ();
@@ -65,97 +61,74 @@ public void setUp() {
6561
6662 @ Test
6763 public void testReverseVectorDenseCP () {
68- runReverseTest (TEST_NAME1 , false , false , ExecType .CP );
64+ runReverseTest (TEST_NAME1 , false , rows1 , 1 , ExecType .CP );
6965 }
7066
7167 @ Test
7268 public void testReverseVectorSparseCP () {
73- runReverseTest (TEST_NAME1 , false , true , ExecType .CP );
69+ runReverseTest (TEST_NAME1 , true , rows1 , 1 , ExecType .CP );
7470 }
7571
7672 @ Test
7773 public void testReverseVectorDenseCPMultiThread () {
78- runReverseTestMultiThread (TEST_NAME1 , false , false , ExecType .CP );
74+ runReverseTest (TEST_NAME1 , false , rows2 , 1 , ExecType .CP );
7975 }
8076
8177 @ Test
8278 public void testReverseVectorSparseCPMultiThread () {
83- runReverseTestMultiThread (TEST_NAME1 , false , true , ExecType .CP );
84- }
85-
86- @ Test
87- public void testReverseVectorDenseSPMultiThread () {
88- runReverseTestMultiThread (TEST_NAME1 , false , false , ExecType .SPARK );
79+ runReverseTest (TEST_NAME1 , true , rows2 , 1 , ExecType .CP );
8980 }
9081
9182 @ Test
9283 public void testReverseVectorDenseSP () {
93- runReverseTest (TEST_NAME1 , false , false , ExecType .SPARK );
84+ runReverseTest (TEST_NAME1 , false , rows2 , 1 , ExecType .SPARK );
9485 }
9586
9687 @ Test
9788 public void testReverseVectorSparseSP () {
98- runReverseTest (TEST_NAME1 , false , true , ExecType .SPARK );
89+ runReverseTest (TEST_NAME1 , true , rows2 , 1 , ExecType .SPARK );
9990 }
10091
10192 @ Test
10293 public void testReverseMatrixDenseCP () {
103- runReverseTest (TEST_NAME1 , true , false , ExecType .CP );
94+ runReverseTest (TEST_NAME1 , false , rows1 , cols1 , ExecType .CP );
10495 }
10596
10697 @ Test
10798 public void testReverseMatrixSparseCP () {
108- runReverseTest (TEST_NAME1 , true , true , ExecType .CP );
99+ runReverseTest (TEST_NAME1 , true , rows1 , cols1 , ExecType .CP );
109100 }
110101
111102 @ Test
112103 public void testReverseMatrixDenseSP () {
113- runReverseTest (TEST_NAME1 , true , false , ExecType .SPARK );
104+ runReverseTest (TEST_NAME1 , false , rows2 , cols2 , ExecType .SPARK );
114105 }
115106
116107 @ Test
117108 public void testReverseMatrixSparseSP () {
118- runReverseTest (TEST_NAME1 , true , true , ExecType .SPARK );
109+ runReverseTest (TEST_NAME1 , true , rows2 , cols2 , ExecType .SPARK );
119110 }
120111
121112 @ Test
122113 public void testReverseVectorDenseRewriteCP () {
123- runReverseTest (TEST_NAME2 , false , false , ExecType .CP );
114+ runReverseTest (TEST_NAME2 , false , rows1 , 1 , ExecType .CP );
124115 }
125116
126117 @ Test
127118 public void testReverseMatrixDenseRewriteCP () {
128- runReverseTest (TEST_NAME2 , true , false , ExecType .CP );
129- }
130-
119+ runReverseTest (TEST_NAME2 , false , rows1 , 1 , ExecType .CP );
120+ }
131121
132- /**
133- *
134- * @param sparseM1
135- * @param sparseM2
136- * @param instType
137- */
138- private void runReverseTest (String testname , boolean matrix , boolean sparse , ExecType instType )
122+ private void runReverseTest (String testname , boolean sparse , int rows , int cols , ExecType instType )
139123 {
140- //rtplatform for MR
141- ExecMode platformOld = rtplatform ;
142- switch ( instType ){
143- case SPARK : rtplatform = ExecMode .SPARK ; break ;
144- default : rtplatform = ExecMode .HYBRID ; break ;
145- }
146- boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
147- if ( rtplatform == ExecMode .SPARK )
148- DMLScript .USE_LOCAL_SPARK_CONFIG = true ;
149-
124+ ExecMode platformOld = setExecMode (instType );
150125 String TEST_NAME = testname ;
151126
152127 try
153128 {
154- int cols = matrix ? cols1 : 1 ;
155129 double sparsity = sparse ? sparsity2 : sparsity1 ;
156130 getAndLoadTestConfiguration (TEST_NAME );
157131
158- /* This is for running the junit test the new way, i.e., construct the arguments directly */
159132 String HOME = SCRIPT_DIR + TEST_DIR ;
160133 fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
161134 programArgs = new String []{"-stats" ,"-explain" ,"-args" , input ("A" ), output ("B" ) };
@@ -164,10 +137,10 @@ private void runReverseTest(String testname, boolean matrix, boolean sparse, Exe
164137 rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir () + " " + expectedDir ();
165138
166139 //generate actual dataset
167- double [][] A = getRandomMatrix (rows1 , cols , -1 , 1 , sparsity , 7 );
140+ double [][] A = getRandomMatrix (rows , cols , -1 , 1 , sparsity , 7 );
168141 writeInputMatrixWithMTD ("A" , A , true );
169142
170- runTest (true , false , null , -1 );
143+ runTest (true , false , null , -1 );
171144 runRScript (true );
172145
173146 //compare matrices
@@ -181,85 +154,8 @@ private void runReverseTest(String testname, boolean matrix, boolean sparse, Exe
181154 else if ( instType == ExecType .SPARK )
182155 Assert .assertTrue ("Missing opcode: " +Instruction .SP_INST_PREFIX +Opcodes .REV .toString (), Statistics .getCPHeavyHitterOpCodes ().contains (Instruction .SP_INST_PREFIX +Opcodes .REV ));
183156 }
184- finally
185- {
186- //reset flags
187- rtplatform = platformOld ;
188- DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
189- }
190- }
191-
192- private void runReverseTestMultiThread (String testname , boolean matrix , boolean sparse , ExecType instType )
193- {
194- // Compare single-thread vs multi-thread results
195- // HashMap<CellIndex, Double> stResult = runReverseWithThreads(testname, matrix, sparse, instType, 1);
196- HashMap <CellIndex , Double > mtResult = runReverseWithThreads (testname , matrix , sparse , instType , 8 );
197-
198- // Compare results to ensure consistency
199- // TestUtils.compareMatrices(stResult, mtResult, 0, "ST-Result", "MT-Result");
200- }
201-
202- private HashMap <CellIndex , Double > runReverseWithThreads (String testname , boolean matrix , boolean sparse , ExecType instType , int numThreads )
203- {
204- //rtplatform for MR
205- ExecMode platformOld = rtplatform ;
206- switch ( instType ){
207- case SPARK : rtplatform = ExecMode .SPARK ; break ;
208- default : rtplatform = ExecMode .HYBRID ; break ;
209- }
210- boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
211- if ( rtplatform == ExecMode .SPARK )
212- DMLScript .USE_LOCAL_SPARK_CONFIG = true ;
213-
214- String TEST_NAME = testname ;
215-
216- System .out .println ("I am trying to run multi-thread" );
217-
218- try
219- {
220- System .setProperty ("sysds.parallel.threads" , String .valueOf (numThreads ));
221-
222- // int cols = matrix ? cols_mt : 1;
223- double sparsity = sparse ? sparsity2 : sparsity1 ;
224- getAndLoadTestConfiguration (TEST_NAME );
225-
226- /* This is for running the junit test the new way, i.e., construct the arguments directly */
227- String HOME = SCRIPT_DIR + TEST_DIR ;
228- fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
229-
230- // Add thread count to program arguments
231- programArgs = new String []{"-stats" ,"-explain" ,"-args" , input ("A" ), output ("B" ) };
232-
233- fullRScriptName = HOME + TEST_NAME + ".R" ;
234- rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir () + " " + expectedDir ();
235-
236- //generate actual dataset
237- double [][] A = getRandomMatrix (rows_mt , cols_mt , -1 , 1 , sparsity , 7 );
238- writeInputMatrixWithMTD ("A" , A , true );
239-
240- // Run with specified thread count (this is the key part)
241- runTest (true , false , null , -1 );
242-
243- //read and return results
244- HashMap <CellIndex , Double > dmlfile = readDMLMatrixFromOutputDir ("B" );
245-
246- //check generated opcode
247- if ( instType == ExecType .CP )
248- Assert .assertTrue ("Missing opcode: rev" , Statistics .getCPHeavyHitterOpCodes ().contains (Opcodes .REV .toString ()));
249- else if ( instType == ExecType .SPARK )
250- Assert .assertTrue ("Missing opcode: " +Instruction .SP_INST_PREFIX +Opcodes .REV .toString (), Statistics .getCPHeavyHitterOpCodes ().contains (Instruction .SP_INST_PREFIX +Opcodes .REV ));
251-
252- return dmlfile ;
253- }
254- catch (Exception ex ) {
255- throw new RuntimeException (ex );
256- }
257157 finally {
258- //reset flags
259- rtplatform = platformOld ;
260- DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
261- System .setProperty ("sysds.parallel.threads" , String .valueOf (oldPar ));
158+ resetExecMode (platformOld );
262159 }
263160 }
264-
265- }
161+ }
0 commit comments