3434import org .apache .sysds .runtime .matrix .data .LibMatrixReorg ;
3535import org .apache .sysds .runtime .matrix .data .MatrixBlock ;
3636import org .apache .sysds .runtime .matrix .operators .BinaryOperator ;
37+ import org .apache .sysds .utils .stats .Timing ;
3738
3839/**
3940 * Support compressed MM chain operation to fuse the following cases :
5354public final class CLALibMMChain {
5455 static final Log LOG = LogFactory .getLog (CLALibMMChain .class .getName ());
5556
57+ /** Reusable cache intermediate double array for temporary decompression */
58+ private static ThreadLocal <double []> cacheIntermediate = null ;
59+
5660 private CLALibMMChain () {
5761 // private constructor
5862 }
@@ -87,20 +91,31 @@ private CLALibMMChain() {
8791 public static MatrixBlock mmChain (CompressedMatrixBlock x , MatrixBlock v , MatrixBlock w , MatrixBlock out ,
8892 ChainType ctype , int k ) {
8993
94+ Timing t = new Timing ();
9095 if (x .isEmpty ())
9196 return returnEmpty (x , out );
9297
9398 // Morph the columns to efficient types for the operation.
9499 x = filterColGroups (x );
100+ double preFilterTime = t .stop ();
95101
96102 // Allow overlapping intermediate if the intermediate is guaranteed not to be overlapping.
97103 final boolean allowOverlap = x .getColGroups ().size () == 1 && isOverlappingAllowed ();
98104
99105 // Right hand side multiplication
100- MatrixBlock tmp = CLALibRightMultBy .rightMultByMatrix (x , v , null , k , allowOverlap );
106+ MatrixBlock tmp = CLALibRightMultBy .rightMultByMatrix (x , v , null , k , true );
107+
108+ double rmmTime = t .stop ();
101109
102- if (ctype == ChainType .XtwXv ) // Multiply intermediate with vector if needed
110+ if (ctype == ChainType .XtwXv ) { // Multiply intermediate with vector if needed
103111 tmp = binaryMultW (tmp , w , k );
112+ }
113+
114+ if (!allowOverlap && tmp instanceof CompressedMatrixBlock ) {
115+ tmp = decompressIntermediate ((CompressedMatrixBlock ) tmp , k );
116+ }
117+
118+ double decompressTime = t .stop ();
104119
105120 if (tmp instanceof CompressedMatrixBlock )
106121 // Compressed Compressed Matrix Multiplication
@@ -109,12 +124,50 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix
109124 // LMM with Compressed - uncompressed multiplication.
110125 CLALibLeftMultBy .leftMultByMatrixTransposed (x , tmp , out , k );
111126
127+ double lmmTime = t .stop ();
112128 if (out .getNumColumns () != 1 ) // transpose the output to make it a row output if needed
113129 out = LibMatrixReorg .transposeInPlace (out , k );
114130
131+ if (LOG .isDebugEnabled ()) {
132+ StringBuilder sb = new StringBuilder ("\n " );
133+ sb .append ("\n PreFilter Time : " + preFilterTime );
134+ sb .append ("\n Chain RMM : " + rmmTime );
135+ sb .append ("\n Chain RMM Decompress: " + decompressTime );
136+ sb .append ("\n Chain LMM : " + lmmTime );
137+ sb .append ("\n Chain Transpose : " + t .stop ());
138+ LOG .debug (sb .toString ());
139+ }
140+
115141 return out ;
116142 }
117143
144+ private static MatrixBlock decompressIntermediate (CompressedMatrixBlock tmp , int k ) {
145+ // cacheIntermediate
146+ final int rows = tmp .getNumRows ();
147+ final int cols = tmp .getNumColumns ();
148+ final int nCells = rows * cols ;
149+ final double [] tmpArr ;
150+ if (cacheIntermediate == null ) {
151+ tmpArr = new double [nCells ];
152+ cacheIntermediate = new ThreadLocal <>();
153+ cacheIntermediate .set (tmpArr );
154+ }
155+ else {
156+ double [] cachedArr = cacheIntermediate .get ();
157+ if (cachedArr == null || cachedArr .length < nCells ) {
158+ tmpArr = new double [nCells ];
159+ cacheIntermediate .set (tmpArr );
160+ }
161+ else {
162+ tmpArr = cachedArr ;
163+ }
164+ }
165+
166+ final MatrixBlock tmpV = new MatrixBlock (tmp .getNumRows (), tmp .getNumColumns (), tmpArr );
167+ CLALibDecompress .decompressTo ((CompressedMatrixBlock ) tmp , tmpV , 0 , 0 , k , false , true );
168+ return tmpV ;
169+ }
170+
118171 private static boolean isOverlappingAllowed () {
119172 return ConfigurationManager .getDMLConfig ().getBooleanValue (DMLConfig .COMPRESSED_OVERLAPPING );
120173 }
@@ -146,6 +199,8 @@ private static CompressedMatrixBlock filterColGroups(CompressedMatrixBlock x) {
146199 final List <AColGroup > groups = x .getColGroups ();
147200 final boolean shouldFilter = CLALibUtils .shouldPreFilter (groups );
148201 if (shouldFilter ) {
202+ if (CLALibUtils .alreadyPreFiltered (groups , x .getNumColumns ()))
203+ return x ;
149204 final int nCol = x .getNumColumns ();
150205 final double [] constV = new double [nCol ];
151206 final List <AColGroup > filteredGroups = CLALibUtils .filterGroups (groups , constV );
0 commit comments