2626import java .io .ObjectOutput ;
2727import java .lang .ref .SoftReference ;
2828import java .util .ArrayList ;
29+ import java .util .HashSet ;
2930import java .util .Iterator ;
3031import java .util .List ;
32+ import java .util .Set ;
3133import java .util .concurrent .ExecutorService ;
3234import java .util .concurrent .Future ;
3335
4244import org .apache .sysds .runtime .DMLRuntimeException ;
4345import org .apache .sysds .runtime .compress .colgroup .AColGroup ;
4446import org .apache .sysds .runtime .compress .colgroup .AColGroup .CompressionType ;
47+ import org .apache .sysds .runtime .compress .colgroup .ADictBasedColGroup ;
4548import org .apache .sysds .runtime .compress .colgroup .ColGroupEmpty ;
4649import org .apache .sysds .runtime .compress .colgroup .ColGroupIO ;
4750import org .apache .sysds .runtime .compress .colgroup .ColGroupUncompressed ;
51+ import org .apache .sysds .runtime .compress .colgroup .dictionary .IDictionary ;
4852import org .apache .sysds .runtime .compress .lib .CLALibAppend ;
4953import org .apache .sysds .runtime .compress .lib .CLALibBinaryCellOp ;
5054import org .apache .sysds .runtime .compress .lib .CLALibCMOps ;
@@ -99,14 +103,13 @@ public class CompressedMatrixBlock extends MatrixBlock {
99103 private static final Log LOG = LogFactory .getLog (CompressedMatrixBlock .class .getName ());
100104 private static final long serialVersionUID = 73193720143154058L ;
101105
102- /**
103- * Debugging flag for Compressed Matrices
104- */
106+ /** Debugging flag for Compressed Matrices */
105107 public static boolean debug = false ;
106108
107- /**
108- * Column groups
109- */
109+ /** Allow caching of the uncompressed block; when false, a cached decompressed version is never handed out */
110+ public static boolean allowCachingUncompressed = true ;
111+
112+ /** Column groups */
110113 protected transient List <AColGroup > _colGroups ;
111114
112115 /**
@@ -119,6 +122,9 @@ public class CompressedMatrixBlock extends MatrixBlock {
119122 */
120123 protected transient SoftReference <MatrixBlock > decompressedVersion ;
121124
125+ /** Cached Memory size */
126+ protected transient long cachedMemorySize = -1 ;
127+
122128 public CompressedMatrixBlock () {
123129 super (true );
124130 sparse = false ;
@@ -169,7 +175,9 @@ protected CompressedMatrixBlock(MatrixBlock uncompressedMatrixBlock) {
169175 clen = uncompressedMatrixBlock .getNumColumns ();
170176 sparse = false ;
171177 nonZeros = uncompressedMatrixBlock .getNonZeros ();
172- decompressedVersion = new SoftReference <>(uncompressedMatrixBlock );
178+ if (!(uncompressedMatrixBlock instanceof CompressedMatrixBlock )) {
179+ decompressedVersion = new SoftReference <>(uncompressedMatrixBlock );
180+ }
173181 }
174182
175183 /**
@@ -189,6 +197,7 @@ public CompressedMatrixBlock(int rl, int cl, long nnz, boolean overlapping, List
189197 this .nonZeros = nnz ;
190198 this .overlappingColGroups = overlapping ;
191199 this ._colGroups = groups ;
200+ getInMemorySize (); // cache memory size
192201 }
193202
194203 @ Override
@@ -204,6 +213,7 @@ public void reset(int rl, int cl, boolean sp, long estnnz, double val) {
204213 * @param cg The column group to use after.
205214 */
206215 public void allocateColGroup (AColGroup cg ) {
216+ cachedMemorySize = -1 ;
207217 _colGroups = new ArrayList <>(1 );
208218 _colGroups .add (cg );
209219 }
@@ -270,6 +280,12 @@ public synchronized MatrixBlock decompress(int k) {
270280
271281 ret = CLALibDecompress .decompress (this , k );
272282
283+ if (ret .getNonZeros () <= 0 ) {
284+ LOG .warn ("Decompress incorrectly set nnz to 0 or -1" );
285+ ret .recomputeNonZeros (k );
286+ }
287+ ret .examSparsity (k );
288+
273289 // Set soft reference to the decompressed version
274290 decompressedVersion = new SoftReference <>(ret );
275291
@@ -290,7 +306,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
290306 * @return The cached decompressed matrix, if it does not exist return null
291307 */
292308 public MatrixBlock getCachedDecompressed () {
293- if (decompressedVersion != null ) {
309+ if ( allowCachingUncompressed && decompressedVersion != null ) {
294310 final MatrixBlock mb = decompressedVersion .get ();
295311 if (mb != null ) {
296312 DMLCompressionStatistics .addDecompressCacheCount ();
@@ -302,6 +318,7 @@ public MatrixBlock getCachedDecompressed() {
302318 }
303319
304320 public CompressedMatrixBlock squash (int k ) {
321+ cachedMemorySize = -1 ;
305322 return CLALibSquash .squash (this , k );
306323 }
307324
@@ -377,12 +394,27 @@ public long estimateSizeInMemory() {
377394 * @return an upper bound on the memory used to store this compressed block considering class overhead.
378395 */
379396 public long estimateCompressedSizeInMemory () {
380- long total = baseSizeInMemory ();
381397
382- for (AColGroup grp : _colGroups )
383- total += grp .estimateInMemorySize ();
398+ if (cachedMemorySize <= -1L ) {
399+
400+ long total = baseSizeInMemory ();
401+ // take into consideration duplicate dictionaries
402+ Set <IDictionary > dicts = new HashSet <>();
403+ for (AColGroup grp : _colGroups ){
404+ if (grp instanceof ADictBasedColGroup ){
405+ IDictionary dg = ((ADictBasedColGroup ) grp ).getDictionary ();
406+ if (dicts .contains (dg ))
407+ total -= dg .getInMemorySize ();
408+ dicts .add (dg );
409+ }
410+ total += grp .estimateInMemorySize ();
411+ }
412+ cachedMemorySize = total ;
413+ return total ;
384414
385- return total ;
415+ }
416+ else
417+ return cachedMemorySize ;
386418 }
387419
388420 public static long baseSizeInMemory () {
@@ -392,6 +424,7 @@ public static long baseSizeInMemory() {
392424 total += 8 ; // Col Group Ref
393425 total += 8 ; // v reference
394426 total += 8 ; // soft reference to decompressed version
427+ total += 8 ; // long cached memory size
395428 total += 1 + 7 ; // Booleans plus padding
396429
397430 total += 40 ; // Col Group Array List
@@ -431,6 +464,7 @@ public long estimateSizeOnDisk() {
431464
432465 @ Override
433466 public void readFields (DataInput in ) throws IOException {
467+ cachedMemorySize = -1 ;
434468 // deserialize compressed block
435469 rlen = in .readInt ();
436470 clen = in .readInt ();
@@ -736,8 +770,22 @@ public MatrixBlock rexpandOperations(MatrixBlock ret, double max, boolean rows,
736770
737771 @ Override
738772 public boolean isEmptyBlock (boolean safe ) {
739- final long nonZeros = getNonZeros ();
740- return _colGroups == null || nonZeros == 0 || (nonZeros == -1 && recomputeNonZeros () == 0 );
773+ if (nonZeros > 1 )
774+ return false ;
775+ else if (_colGroups == null || nonZeros == 0 )
776+ return true ;
777+ else {
778+ if (nonZeros == -1 ){
779+ // try to use column groups
780+ for (AColGroup g : _colGroups )
781+ if (!g .isEmpty ())
782+ return false ;
783+ // Otherwise recompute non zeros.
784+ recomputeNonZeros ();
785+ }
786+
787+ return getNonZeros () == 0 ;
788+ }
741789 }
742790
743791 @ Override
@@ -1045,6 +1093,7 @@ public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, boolean awareD
10451093 }
10461094
10471095 private void copyCompressedMatrix (CompressedMatrixBlock that ) {
1096+ cachedMemorySize = -1 ;
10481097 this .rlen = that .getNumRows ();
10491098 this .clen = that .getNumColumns ();
10501099 this .sparseBlock = null ;
@@ -1059,7 +1108,7 @@ private void copyCompressedMatrix(CompressedMatrixBlock that) {
10591108 }
10601109
10611110 public SoftReference <MatrixBlock > getSoftReferenceToDecompressed () {
1062- return decompressedVersion ;
1111+ return allowCachingUncompressed ? decompressedVersion : null ;
10631112 }
10641113
10651114 public void clearSoftReferenceToDecompressed () {
0 commit comments