3131import org .apache .sysds .hops .OptimizerUtils ;
3232import org .apache .sysds .runtime .DMLRuntimeException ;
3333import org .apache .sysds .runtime .compress .CompressedMatrixBlock ;
34+ import org .apache .sysds .runtime .compress .DMLCompressionException ;
3435import org .apache .sysds .runtime .compress .colgroup .AColGroup ;
3536import org .apache .sysds .runtime .compress .colgroup .ColGroupConst ;
3637import org .apache .sysds .runtime .compress .colgroup .ColGroupEmpty ;
3738import org .apache .sysds .runtime .compress .colgroup .ColGroupOLE ;
3839import org .apache .sysds .runtime .compress .colgroup .ColGroupUncompressed ;
3940import org .apache .sysds .runtime .compress .colgroup .indexes .IColIndex ;
41+ import org .apache .sysds .runtime .data .DenseBlock ;
4042import org .apache .sysds .runtime .functionobjects .Divide ;
4143import org .apache .sysds .runtime .functionobjects .Minus ;
4244import org .apache .sysds .runtime .functionobjects .Multiply ;
@@ -57,10 +59,13 @@ private CLALibScalar() {
5759 }
5860
5961 public static MatrixBlock scalarOperations (ScalarOperator sop , CompressedMatrixBlock m1 , MatrixValue result ) {
62+ // Timing time = new Timing(true);
6063 if (isInvalidForCompressedOutput (m1 , sop )) {
6164 LOG .warn ("scalar overlapping not supported for op: " + sop .fn .getClass ().getSimpleName ());
62- MatrixBlock m1d = m1 .decompress (sop .getNumThreads ());
63- return m1d .scalarOperations (sop , result );
65+
66+ return fusedScalarAndDecompress (m1 , sop );
67+ // MatrixBlock m1d = m1.decompress(sop.getNumThreads());
68+ // return m1d.scalarOperations(sop, result);
6469 }
6570 CompressedMatrixBlock ret = setupRet (m1 , result );
6671
@@ -89,11 +94,84 @@ public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixB
8994 ret .setOverlapping (m1 .isOverlapping ());
9095 }
9196
92- ret .recomputeNonZeros ();
97+ if (sop .fn instanceof Divide ) {
98+ ret .setNonZeros (m1 .getNonZeros ());
99+ }
100+ else {
101+ ret .recomputeNonZeros ();
102+ }
93103
104+ // System.out.println("CLA Scalar: " + sop + " " + m1.getNumRows() + ", " + m1.getNumColumns() + ", " +
105+ // m1.getColGroups().size()
106+ // + " -- " + "\t\t" + time.stop());
94107 return ret ;
95108 }
96109
110+ private static MatrixBlock fusedScalarAndDecompress (CompressedMatrixBlock in , ScalarOperator sop ) {
111+ int k = sop .getNumThreads ();
112+ ExecutorService pool = CommonThreadPool .get (k );
113+ try {
114+ final int nRow = in .getNumRows ();
115+ final int nCol = in .getNumColumns ();
116+ final MatrixBlock out = new MatrixBlock (nRow , nCol , false );
117+ final List <AColGroup > groups = in .getColGroups ();
118+ out .allocateDenseBlock ();
119+ final DenseBlock db = out .getDenseBlock ();
120+ final int blkz = Math .max ((int )(Math .ceil ((double )nRow / k )), 256 );
121+ final List <Future <Long >> tasks = new ArrayList <>();
122+ for (int i = 0 ; i < nRow ; i += blkz ) {
123+ final int start = i ;
124+ final int end = Math .min (i + blkz , nRow );
125+ tasks .add (pool .submit (() -> fusedDecompressAndScalar (groups , nCol , start , end , db , sop )));
126+ }
127+ long nnz = 0 ;
128+ for (Future <Long > t : tasks ) {
129+ nnz += t .get ();
130+ }
131+ out .setNonZeros (nnz );
132+ out .examSparsity (true , k );
133+ return out ;
134+ }
135+ catch (Exception e ) {
136+ throw new DMLCompressionException ("failed fused scalar operation" , e );
137+ }
138+ finally {
139+ pool .shutdown ();
140+ }
141+
142+ // MatrixBlock m1d = m1.decompress(sop.getNumThreads());
143+ // return m1d.scalarOperations(sop, result);
144+ }
145+
146+ private static long fusedDecompressAndScalar (final List <AColGroup > groups , int nCol , int start , int end ,
147+ DenseBlock db , ScalarOperator sop ) {
148+ long nnz = 0 ;
149+ for (int b = start ; b < end ; b += 32 ) {
150+ int bs = b ;
151+ int be = Math .min (b + 32 , end );
152+ nnz += fusedDecompressAndScalarBlock (groups , nCol , bs , be , db , sop );
153+ }
154+ return nnz ;
155+ }
156+
157+ private static long fusedDecompressAndScalarBlock (final List <AColGroup > groups , int nCol , int bs , int be ,
158+ DenseBlock db , ScalarOperator sop ) {
159+ long nnz = 0 ;
160+ for (AColGroup g : groups ) {
161+ // main block to optimize is decompression speed since it is most likely an overlapping input
162+ g .decompressToDenseBlock (db , bs , be );
163+ }
164+ for (int r = bs ; r < be ; r ++) {
165+ double [] vals = db .values (r );
166+ int off = db .pos (r );
167+ for (int c = off ; c < nCol + off ; c ++) {
168+ vals [c ] = sop .executeScalar (vals [c ]);
169+ nnz += vals [c ] == 0 ? 0 : 1 ;
170+ }
171+ }
172+ return nnz ;
173+ }
174+
97175 private static CompressedMatrixBlock setupRet (CompressedMatrixBlock m1 , MatrixValue result ) {
98176 CompressedMatrixBlock ret ;
99177 if (result == null || !(result instanceof CompressedMatrixBlock ))
0 commit comments