2626import java .util .concurrent .ExecutorService ;
2727import java .util .concurrent .Future ;
2828
29+ import org .apache .commons .lang3 .NotImplementedException ;
2930import org .apache .commons .logging .Log ;
3031import org .apache .commons .logging .LogFactory ;
3132import org .apache .sysds .api .DMLScript ;
33+ import org .apache .sysds .runtime .DMLRuntimeException ;
3234import org .apache .sysds .runtime .compress .CompressedMatrixBlock ;
3335import org .apache .sysds .runtime .compress .DMLCompressionException ;
3436import org .apache .sysds .runtime .compress .colgroup .AColGroup ;
@@ -65,6 +67,11 @@ public static MatrixBlock decompress(CompressedMatrixBlock cmb, int k) {
6567
6668 public static void decompressTo (CompressedMatrixBlock cmb , MatrixBlock ret , int rowOffset , int colOffset , int k ,
6769 boolean countNNz ) {
70+ decompressTo (cmb , ret , rowOffset , colOffset , k , countNNz , false );
71+ }
72+
73+ public static void decompressTo (CompressedMatrixBlock cmb , MatrixBlock ret , int rowOffset , int colOffset , int k ,
74+ boolean countNNz , boolean reset ) {
6875 Timing time = new Timing (true );
6976 if (cmb .getNumColumns () + colOffset > ret .getNumColumns () || cmb .getNumRows () + rowOffset > ret .getNumRows ()) {
7077 LOG .warn (
@@ -78,12 +85,12 @@ public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int
7885
7986 final boolean outSparse = ret .isInSparseFormat ();
8087 if (!cmb .isEmpty ()) {
81- if (outSparse && cmb .isOverlapping ())
88+ if (outSparse && ( cmb .isOverlapping () || reset ))
8289 throw new DMLCompressionException ("Not supported decompression into sparse block from overlapping state" );
8390 else if (outSparse )
8491 decompressToSparseBlock (cmb , ret , rowOffset , colOffset );
8592 else
86- decompressToDenseBlock (cmb , ret .getDenseBlock (), rowOffset , colOffset );
93+ decompressToDenseBlock (cmb , ret .getDenseBlock (), rowOffset , colOffset , k , reset );
8794 }
8895
8996 if (DMLScript .STATISTICS ) {
@@ -94,7 +101,7 @@ else if(outSparse)
94101 }
95102
96103 if (countNNz )
97- ret .recomputeNonZeros ();
104+ ret .recomputeNonZeros (k );
98105 }
99106
100107 private static void decompressToSparseBlock (CompressedMatrixBlock cmb , MatrixBlock ret , int rowOffset ,
@@ -115,23 +122,67 @@ private static void decompressToSparseBlock(CompressedMatrixBlock cmb, MatrixBlo
115122 ret .checkSparseRows ();
116123 }
117124
118- private static void decompressToDenseBlock (CompressedMatrixBlock cmb , DenseBlock ret , int rowOffset , int colOffset ) {
119- final List <AColGroup > groups = cmb .getColGroups ();
125+ private static void decompressToDenseBlock (CompressedMatrixBlock cmb , DenseBlock ret , int rowOffset , int colOffset ,
126+ int k , boolean reset ) {
127+ List <AColGroup > groups = cmb .getColGroups ();
120128 // final int nCols = cmb.getNumColumns();
121129 final int nRows = cmb .getNumRows ();
122130
123131 final boolean shouldFilter = CLALibUtils .shouldPreFilter (groups );
124- if (shouldFilter ) {
132+ if (shouldFilter && ! CLALibUtils . alreadyPreFiltered ( groups , cmb . getNumColumns ()) ) {
125133 final double [] constV = new double [cmb .getNumColumns ()];
126- final List <AColGroup > filteredGroups = CLALibUtils .filterGroups (groups , constV );
127- for (AColGroup g : filteredGroups )
128- g .decompressToDenseBlock (ret , 0 , nRows , rowOffset , colOffset );
134+ groups = CLALibUtils .filterGroups (groups , constV );
129135 AColGroup cRet = ColGroupConst .create (constV );
130- cRet . decompressToDenseBlock ( ret , 0 , nRows , rowOffset , colOffset );
136+ groups . add ( cRet );
131137 }
132- else {
133- for (AColGroup g : groups )
134- g .decompressToDenseBlock (ret , 0 , nRows , rowOffset , colOffset );
138+
139+ if (k > 1 && nRows > 1000 )
140+ decompressToDenseBlockParallel (ret , groups , rowOffset , colOffset , nRows , k , reset );
141+ else
142+ decompressToDenseBlockSingleThread (ret , groups , rowOffset , colOffset , nRows , reset );
143+ }
144+
145+ private static void decompressToDenseBlockSingleThread (DenseBlock ret , List <AColGroup > groups , int rowOffset ,
146+ int colOffset , int nRows , boolean reset ) {
147+ decompressToDenseBlockBlock (ret , groups , rowOffset , colOffset , 0 , nRows , reset );
148+ }
149+
150+ private static void decompressToDenseBlockBlock (DenseBlock ret , List <AColGroup > groups , int rowOffset , int colOffset ,
151+ int rl , int ru , boolean reset ) {
152+ if (reset ) {
153+ if (ret .isContiguous ()) {
154+ final int nCol = ret .getDim (1 );
155+ ret .fillBlock (0 , rl * nCol , ru * nCol , 0.0 );
156+ }
157+ else
158+ throw new NotImplementedException ();
159+ }
160+ for (AColGroup g : groups )
161+ g .decompressToDenseBlock (ret , rl , ru , rowOffset , colOffset );
162+ }
163+
164+ private static void decompressToDenseBlockParallel (DenseBlock ret , List <AColGroup > groups , int rowOffset ,
165+ int colOffset , int nRows , int k , boolean reset ) {
166+
167+ final int blklen = Math .max (nRows / k , 512 );
168+ final ExecutorService pool = CommonThreadPool .get (k );
169+ try {
170+ List <Future <?>> tasks = new ArrayList <>(nRows / blklen );
171+ for (int r = 0 ; r < nRows ; r += blklen ) {
172+ final int start = r ;
173+ final int end = Math .min (nRows , r + blklen );
174+ tasks .add (
175+ pool .submit (() -> decompressToDenseBlockBlock (ret , groups , rowOffset , colOffset , start , end , reset )));
176+ }
177+
178+ for (Future <?> t : tasks )
179+ t .get ();
180+ }
181+ catch (Exception e ) {
182+ throw new DMLCompressionException ("Failed parallel decompress to" );
183+ }
184+ finally {
185+ pool .shutdown ();
135186 }
136187 }
137188
@@ -148,7 +199,7 @@ private static MatrixBlock decompressExecute(CompressedMatrixBlock cmb, int k) {
148199 MatrixBlock ret = getUncompressedColGroupAndRemoveFromListOfColGroups (groups , overlapping , nRows , nCols );
149200
150201 if (ret != null && groups .size () == 0 ) {
151- ret .setNonZeros (ret .recomputeNonZeros ());
202+ ret .setNonZeros (ret .recomputeNonZeros (k ));
152203 return ret ; // if uncompressedColGroup is only colGroup.
153204 }
154205
@@ -182,23 +233,18 @@ private static MatrixBlock decompressExecute(CompressedMatrixBlock cmb, int k) {
182233 constV = null ;
183234
184235 final double eps = getEps (constV );
185-
186236 if (k == 1 ) {
187- if (ret .isInSparseFormat ()) {
237+ if (ret .isInSparseFormat ())
188238 decompressSparseSingleThread (ret , filteredGroups , nRows , blklen );
189- }
190- else {
239+ else
191240 decompressDenseSingleThread (ret , filteredGroups , nRows , blklen , constV , eps , nonZeros , overlapping );
192- }
193241 }
194- else if (ret .isInSparseFormat ()) {
242+ else if (ret .isInSparseFormat ())
195243 decompressSparseMultiThread (ret , filteredGroups , nRows , blklen , k );
196- }
197- else {
244+ else
198245 decompressDenseMultiThread (ret , filteredGroups , nRows , blklen , constV , eps , k , overlapping );
199- }
200246
201- ret .recomputeNonZeros ();
247+ ret .recomputeNonZeros (k );
202248 ret .examSparsity ();
203249
204250 return ret ;
@@ -249,29 +295,40 @@ private static void decompressSparseSingleThread(MatrixBlock ret, List<AColGroup
249295
250296 private static void decompressDenseSingleThread (MatrixBlock ret , List <AColGroup > filteredGroups , int rlen ,
251297 int blklen , double [] constV , double eps , long nonZeros , boolean overlapping ) {
298+
299+ final DenseBlock db = ret .getDenseBlock ();
300+ final int nCol = ret .getNumColumns ();
252301 for (int i = 0 ; i < rlen ; i += blklen ) {
253302 final int rl = i ;
254303 final int ru = Math .min (i + blklen , rlen );
255304 for (AColGroup grp : filteredGroups )
256- grp .decompressToDenseBlock (ret . getDenseBlock () , rl , ru );
305+ grp .decompressToDenseBlock (db , rl , ru );
257306 if (constV != null && !ret .isInSparseFormat ())
258- addVector (ret , constV , eps , rl , ru );
307+ addVector (db , nCol , constV , eps , rl , ru );
259308 }
260309 }
261310
262- protected static void decompressDenseMultiThread (MatrixBlock ret , List <AColGroup > groups , double [] constV , int k ,
263- boolean overlapping ) {
264- final int nRows = ret .getNumRows ();
265- final double eps = getEps (constV );
266- final int blklen = Math .max (nRows / k , 512 );
267- decompressDenseMultiThread (ret , groups , nRows , blklen , constV , eps , k , overlapping );
268- }
311+ // private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV, int k,
312+ // boolean overlapping) {
313+ // final int nRows = ret.getNumRows();
314+ // final double eps = getEps(constV);
315+ // final int blklen = Math.max(nRows / k, 512);
316+ // decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping);
317+ // }
269318
270319 protected static void decompressDenseMultiThread (MatrixBlock ret , List <AColGroup > groups , double [] constV ,
271320 double eps , int k , boolean overlapping ) {
321+
322+ Timing time = new Timing (true );
272323 final int nRows = ret .getNumRows ();
273324 final int blklen = Math .max (nRows / k , 512 );
274325 decompressDenseMultiThread (ret , groups , nRows , blklen , constV , eps , k , overlapping );
326+ if (DMLScript .STATISTICS ) {
327+ final double t = time .stop ();
328+ DMLCompressionStatistics .addDecompressTime (t , k );
329+ if (LOG .isTraceEnabled ())
330+ LOG .trace ("decompressed block w/ k=" + k + " in " + t + "ms." );
331+ }
275332 }
276333
277334 private static void decompressDenseMultiThread (MatrixBlock ret , List <AColGroup > filteredGroups , int rlen , int blklen ,
@@ -297,7 +354,7 @@ private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup>
297354 catch (InterruptedException | ExecutionException ex ) {
298355 throw new DMLCompressionException ("Parallel decompression failed" , ex );
299356 }
300- finally {
357+ finally {
301358 pool .shutdown ();
302359 }
303360 }
@@ -310,13 +367,14 @@ private static void decompressSparseMultiThread(MatrixBlock ret, List<AColGroup>
310367 for (int i = 0 ; i < rlen ; i += blklen )
311368 tasks .add (new DecompressSparseTask (filteredGroups , ret , i , Math .min (i + blklen , rlen )));
312369
370+ LOG .error ("tasks:" + tasks );
313371 for (Future <Object > rt : pool .invokeAll (tasks ))
314372 rt .get ();
315373 }
316374 catch (InterruptedException | ExecutionException ex ) {
317375 throw new DMLCompressionException ("Parallel decompression failed" , ex );
318376 }
319- finally {
377+ finally {
320378 pool .shutdown ();
321379 }
322380 }
@@ -360,22 +418,23 @@ protected DecompressDenseTask(List<AColGroup> colGroups, MatrixBlock ret, double
360418 _eps = eps ;
361419 _rl = rl ;
362420 _ru = ru ;
363- _blklen = 32768 / ret .getNumColumns ();
421+ _blklen = Math . max ( 32768 / ret .getNumColumns (), 128 );
364422 _constV = constV ;
365423 }
366424
367425 @ Override
368426 public Long call () {
369427 try {
370-
428+ final DenseBlock db = _ret .getDenseBlock ();
429+ final int nCol = _ret .getNumColumns ();
371430 long nnz = 0 ;
372431 for (int b = _rl ; b < _ru ; b += _blklen ) {
373432 final int e = Math .min (b + _blklen , _ru );
374433 for (AColGroup grp : _colGroups )
375- grp .decompressToDenseBlock (_ret . getDenseBlock () , b , e );
434+ grp .decompressToDenseBlock (db , b , e );
376435
377436 if (_constV != null )
378- addVector (_ret , _constV , _eps , b , e );
437+ addVector (db , nCol , _constV , _eps , b , e );
379438 nnz += _ret .recomputeNonZeros (b , e - 1 );
380439 }
381440
@@ -404,23 +463,22 @@ protected DecompressDenseSingleColTask(AColGroup grp, MatrixBlock ret, double ep
404463 _eps = eps ;
405464 _rl = rl ;
406465 _ru = ru ;
407- _blklen = 32768 / ret .getNumColumns ();
466+ _blklen = Math . max ( 32768 / ret .getNumColumns (), 128 );
408467 _constV = constV ;
409468 }
410469
411470 @ Override
412471 public Long call () {
413472 try {
414-
473+ final DenseBlock db = _ret .getDenseBlock ();
474+ final int nCol = _ret .getNumColumns ();
415475 long nnz = 0 ;
416476 for (int b = _rl ; b < _ru ; b += _blklen ) {
417477 final int e = Math .min (b + _blklen , _ru );
418- // for(AColGroup grp : _colGroups)
419- _grp .decompressToDenseBlock (_ret .getDenseBlock (), b , e );
478+ _grp .decompressToDenseBlock (db , b , e );
420479
421480 if (_constV != null )
422- addVector (_ret , _constV , _eps , b , e );
423- // nnz += _ret.recomputeNonZeros(b, e - 1);
481+ addVector (db , nCol , _constV , _eps , b , e );
424482 }
425483
426484 return nnz ;
@@ -446,14 +504,21 @@ protected DecompressSparseTask(List<AColGroup> colGroups, MatrixBlock ret, int r
446504 }
447505
448506 @ Override
449- public Object call () {
450- final SparseBlock sb = _ret .getSparseBlock ();
451- for (AColGroup grp : _colGroups )
452- grp .decompressToSparseBlock (_ret .getSparseBlock (), _rl , _ru );
453- for (int i = _rl ; i < _ru ; i ++)
454- if (!sb .isEmpty (i ))
455- sb .sort (i );
456- return null ;
507+ public Object call () throws Exception {
508+ try {
509+
510+ final SparseBlock sb = _ret .getSparseBlock ();
511+ for (AColGroup grp : _colGroups )
512+ grp .decompressToSparseBlock (_ret .getSparseBlock (), _rl , _ru );
513+ for (int i = _rl ; i < _ru ; i ++)
514+ if (!sb .isEmpty (i ))
515+ sb .sort (i );
516+ return null ;
517+ }
518+ catch (Exception e ){
519+ e .printStackTrace ();
520+ throw new DMLRuntimeException (e );
521+ }
457522 }
458523 }
459524
@@ -467,28 +532,32 @@ public Object call() {
467532 * @param rl The row to start at
468533 * @param ru The row to end at
469534 */
470- private static void addVector (final MatrixBlock ret , final double [] rowV , final double eps , final int rl ,
471- final int ru ) {
472- final int nCols = ret .getNumColumns ();
473- final DenseBlock db = ret .getDenseBlock ();
535+ private static final void addVector (final DenseBlock db , final int nCols , final double [] rowV , final double eps ,
536+ final int rl , final int ru ) {
537+ if (eps == 0 )
538+ addVectorEps (db , nCols , rowV , eps , rl , ru );
539+ else
540+ addVectorNoEps (db , nCols , rowV , eps , rl , ru );
541+ }
474542
475- if (nCols == 1 ) {
476- if (eps == 0 )
477- addValue (db .values (0 ), rowV [0 ], rl , ru );
478- else
479- addValueEps (db .values (0 ), rowV [0 ], eps , rl , ru );
480- }
481- else if (db .isContiguous ()) {
482- if (eps == 0 )
483- addVectorContiguousNoEps (db .values (0 ), rowV , nCols , rl , ru );
484- else
485- addVectorContiguousEps (db .values (0 ), rowV , nCols , eps , rl , ru );
486- }
487- else if (eps == 0 )
543+ private static final void addVectorEps (final DenseBlock db , final int nCols , final double [] rowV , final double eps ,
544+ final int rl , final int ru ) {
545+ if (nCols == 1 )
546+ addValue (db .values (0 ), rowV [0 ], rl , ru );
547+ else if (db .isContiguous ())
548+ addVectorContiguousNoEps (db .values (0 ), rowV , nCols , rl , ru );
549+ else
488550 addVectorNoEps (db , rowV , nCols , rl , ru );
551+ }
552+
553+ private static final void addVectorNoEps (final DenseBlock db , final int nCols , final double [] rowV , final double eps ,
554+ final int rl , final int ru ) {
555+ if (nCols == 1 )
556+ addValueEps (db .values (0 ), rowV [0 ], eps , rl , ru );
557+ else if (db .isContiguous ())
558+ addVectorContiguousEps (db .values (0 ), rowV , nCols , eps , rl , ru );
489559 else
490560 addVectorEps (db , rowV , nCols , eps , rl , ru );
491-
492561 }
493562
494563 private static void addValue (final double [] retV , final double v , final int rl , final int ru ) {
0 commit comments