2828import org .apache .sysds .runtime .matrix .data .MatrixBlock ;
2929import org .apache .sysds .runtime .matrix .operators .ReorgOperator ;
3030import org .apache .sysds .test .TestUtils ;
31+ import org .apache .sysds .utils .stats .Timing ;
3132import org .junit .Test ;
3233import org .junit .runner .RunWith ;
3334import org .junit .runners .Parameterized ;
@@ -88,8 +89,17 @@ public static Collection<Object[]> data() {
8889 public void denseRollOperationSingleAndMultiThreadedShouldReturnSameResult () {
8990 int numThreads = getNumThreads ();
9091
92+ // Single-threaded timing
93+ Timing tSingle = new Timing (true );
9194 MatrixBlock outSingle = rollOperation (inputDense , 1 );
95+ double timeSingle = tSingle .stop ();
96+
97+ // Multithreaded timing
98+ Timing tMulti = new Timing (true );
9299 MatrixBlock outMulti = rollOperation (inputDense , numThreads );
100+ double timeMulti = tMulti .stop ();
101+
102+ logTiming ("Dense" , numThreads , timeSingle , timeMulti );
93103
94104 TestUtils .compareMatrices (outSingle , outMulti , 1e-12 ,
95105 "Dense Mismatch (numThreads=1 vs numThreads>1) for Size=" + rows + "x" + cols + " Shift=" + shift );
@@ -99,8 +109,17 @@ public void denseRollOperationSingleAndMultiThreadedShouldReturnSameResult() {
99109 public void sparseRollOperationSingleAndMultiThreadedShouldReturnSameResult () {
100110 int numThreads = getNumThreads ();
101111
112+ // Single-threaded timing
113+ Timing tSingle = new Timing (true );
102114 MatrixBlock outSingle = rollOperation (inputSparse , 1 );
115+ double timeSingle = tSingle .stop ();
116+
117+ // Multithreaded timing
118+ Timing tMulti = new Timing (true );
103119 MatrixBlock outMulti = rollOperation (inputSparse , numThreads );
120+ double timeMulti = tMulti .stop ();
121+
122+ logTiming ("Sparse" , numThreads , timeSingle , timeMulti );
104123
105124 TestUtils .compareMatrices (outSingle , outMulti , 1e-12 ,
106125 "Sparse Mismatch (numThreads=1 vs numThreads>1) for Size=" + rows + "x" + cols + " Shift=" + shift );
@@ -120,4 +139,14 @@ private static int getNumThreads() {
120139 int cores = Runtime .getRuntime ().availableProcessors ();
121140 return Math .max (2 , cores );
122141 }
142+
143+ private void logTiming (String type , int numThreads , double timeSingle , double timeMulti ) {
144+ double speedup = timeSingle / timeMulti ;
145+
146+ System .out .println ("\n --- " + type + " Roll Operation Timing for " + rows + "x" + cols + ", Shift=" + shift + " ---" );
147+ System .out .printf ("Single-threaded (1 core) took: %.3f ms\n " , timeSingle );
148+ System .out .printf ("Multithreaded (%d cores) took: %.3f ms\n " , numThreads , timeMulti );
149+ System .out .printf ("Speedup: %.2f\n " , speedup );
150+ System .out .println ("--------------------------------------------------------------------------------" );
151+ }
123152}
0 commit comments