2222import java .util .HashMap ;
2323
2424import org .apache .sysds .common .Opcodes ;
25+ import org .apache .sysds .utils .stats .InfrastructureAnalyzer ;
2526import org .junit .Assert ;
2627import org .junit .Test ;
2728import org .apache .sysds .api .DMLScript ;
@@ -44,10 +45,17 @@ public class FullReverseTest extends AutomatedTestBase
4445 private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest .class .getSimpleName () + "/" ;
4546
4647 private final static int rows1 = 2017 ;
47- private final static int cols1 = 1001 ;
48+ private final static int cols1 = 1001 ;
4849 private final static double sparsity1 = 0.7 ;
4950 private final static double sparsity2 = 0.1 ;
5051
52+ // Multi-threading test parameters
53+ private final static int rows_mt = 5018 ; // Larger for multi-threading benefits
54+ private final static int cols_mt = 1001 ; // Larger for multi-threading benefits
55+ private final static int [] threadCounts = {1 , 2 , 4 , 8 };
56+ // Set global parallelism for SystemDS to enable multi-threading
57+ private final static int oldPar = InfrastructureAnalyzer .getLocalParallelism ();
58+
5159 @ Override
5260 public void setUp () {
5361 TestUtils .clearAssertionInformation ();
@@ -64,7 +72,22 @@ public void testReverseVectorDenseCP() {
6472 public void testReverseVectorSparseCP () {
6573 runReverseTest (TEST_NAME1 , false , true , ExecType .CP );
6674 }
67-
75+
76+ @ Test
77+ public void testReverseVectorDenseCPMultiThread () {
78+ runReverseTestMultiThread (TEST_NAME1 , false , false , ExecType .CP );
79+ }
80+
81+ @ Test
82+ public void testReverseVectorSparseCPMultiThread () {
83+ runReverseTestMultiThread (TEST_NAME1 , false , true , ExecType .CP );
84+ }
85+
86+ @ Test
87+ public void testReverseVectorDenseSPMultiThread () {
88+ runReverseTestMultiThread (TEST_NAME1 , false , false , ExecType .SPARK );
89+ }
90+
6891 @ Test
6992 public void testReverseVectorDenseSP () {
7093 runReverseTest (TEST_NAME1 , false , false , ExecType .SPARK );
@@ -165,6 +188,78 @@ else if ( instType == ExecType.SPARK )
165188 DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
166189 }
167190 }
168-
191+
192+ private void runReverseTestMultiThread (String testname , boolean matrix , boolean sparse , ExecType instType )
193+ {
194+ // Compare single-thread vs multi-thread results
195+ // HashMap<CellIndex, Double> stResult = runReverseWithThreads(testname, matrix, sparse, instType, 1);
196+ HashMap <CellIndex , Double > mtResult = runReverseWithThreads (testname , matrix , sparse , instType , 8 );
197+
198+ // Compare results to ensure consistency
199+ // TestUtils.compareMatrices(stResult, mtResult, 0, "ST-Result", "MT-Result");
200+ }
201+
202+ private HashMap <CellIndex , Double > runReverseWithThreads (String testname , boolean matrix , boolean sparse , ExecType instType , int numThreads )
203+ {
204+ //rtplatform for MR
205+ ExecMode platformOld = rtplatform ;
206+ switch ( instType ){
207+ case SPARK : rtplatform = ExecMode .SPARK ; break ;
208+ default : rtplatform = ExecMode .HYBRID ; break ;
209+ }
210+ boolean sparkConfigOld = DMLScript .USE_LOCAL_SPARK_CONFIG ;
211+ if ( rtplatform == ExecMode .SPARK )
212+ DMLScript .USE_LOCAL_SPARK_CONFIG = true ;
213+
214+ String TEST_NAME = testname ;
215+
216+ System .out .println ("I am trying to run multi-thread" );
217+
218+ try
219+ {
220+ System .setProperty ("sysds.parallel.threads" , String .valueOf (numThreads ));
221+
222+ // int cols = matrix ? cols_mt : 1;
223+ double sparsity = sparse ? sparsity2 : sparsity1 ;
224+ getAndLoadTestConfiguration (TEST_NAME );
225+
226+ /* This is for running the junit test the new way, i.e., construct the arguments directly */
227+ String HOME = SCRIPT_DIR + TEST_DIR ;
228+ fullDMLScriptName = HOME + TEST_NAME + ".dml" ;
229+
230+ // Add thread count to program arguments
231+ programArgs = new String []{"-stats" ,"-explain" ,"-args" , input ("A" ), output ("B" ) };
232+
233+ fullRScriptName = HOME + TEST_NAME + ".R" ;
234+ rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir () + " " + expectedDir ();
235+
236+ //generate actual dataset
237+ double [][] A = getRandomMatrix (rows_mt , cols_mt , -1 , 1 , sparsity , 7 );
238+ writeInputMatrixWithMTD ("A" , A , true );
239+
240+ // Run with specified thread count (this is the key part)
241+ runTest (true , false , null , -1 );
242+
243+ //read and return results
244+ HashMap <CellIndex , Double > dmlfile = readDMLMatrixFromOutputDir ("B" );
245+
246+ //check generated opcode
247+ if ( instType == ExecType .CP )
248+ Assert .assertTrue ("Missing opcode: rev" , Statistics .getCPHeavyHitterOpCodes ().contains (Opcodes .REV .toString ()));
249+ else if ( instType == ExecType .SPARK )
250+ Assert .assertTrue ("Missing opcode: " +Instruction .SP_INST_PREFIX +Opcodes .REV .toString (), Statistics .getCPHeavyHitterOpCodes ().contains (Instruction .SP_INST_PREFIX +Opcodes .REV ));
251+
252+ return dmlfile ;
253+ }
254+ catch (Exception ex ) {
255+ throw new RuntimeException (ex );
256+ }
257+ finally {
258+ //reset flags
259+ rtplatform = platformOld ;
260+ DMLScript .USE_LOCAL_SPARK_CONFIG = sparkConfigOld ;
261+ System .setProperty ("sysds.parallel.threads" , String .valueOf (oldPar ));
262+ }
263+ }
169264
170265}
0 commit comments