Skip to content

Commit b96cf25

Browse files
committed
[MINOR] Compressed Dictionary Tests
This commit adds some (apparently) much-needed tests, primarily focusing on the Dictionary abstractions used for most of the dictionaries. What started as 100 lines of tests grew into 3.3k lines of changes across many files in the compression framework. Closes #2183
1 parent 6b37c85 commit b96cf25

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+3353
-1832
lines changed

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
import java.io.ObjectOutput;
2727
import java.lang.ref.SoftReference;
2828
import java.util.ArrayList;
29+
import java.util.HashSet;
2930
import java.util.Iterator;
3031
import java.util.List;
32+
import java.util.Set;
3133
import java.util.concurrent.ExecutorService;
3234
import java.util.concurrent.Future;
3335

@@ -42,9 +44,11 @@
4244
import org.apache.sysds.runtime.DMLRuntimeException;
4345
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
4446
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
47+
import org.apache.sysds.runtime.compress.colgroup.ADictBasedColGroup;
4548
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
4649
import org.apache.sysds.runtime.compress.colgroup.ColGroupIO;
4750
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
51+
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
4852
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
4953
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
5054
import org.apache.sysds.runtime.compress.lib.CLALibCMOps;
@@ -99,14 +103,13 @@ public class CompressedMatrixBlock extends MatrixBlock {
99103
private static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName());
100104
private static final long serialVersionUID = 73193720143154058L;
101105

102-
/**
103-
* Debugging flag for Compressed Matrices
104-
*/
106+
/** Debugging flag for Compressed Matrices */
105107
public static boolean debug = false;
106108

107-
/**
108-
* Column groups
109-
*/
109+
/** Allow use of the cached uncompressed block; set to false to disallow it */
110+
public static boolean allowCachingUncompressed = true;
111+
112+
/** Column groups */
110113
protected transient List<AColGroup> _colGroups;
111114

112115
/**
@@ -119,6 +122,9 @@ public class CompressedMatrixBlock extends MatrixBlock {
119122
*/
120123
protected transient SoftReference<MatrixBlock> decompressedVersion;
121124

125+
/** Cached in-memory size estimate; -1 means it has not been computed yet */
126+
protected transient long cachedMemorySize = -1;
127+
122128
public CompressedMatrixBlock() {
123129
super(true);
124130
sparse = false;
@@ -169,7 +175,9 @@ protected CompressedMatrixBlock(MatrixBlock uncompressedMatrixBlock) {
169175
clen = uncompressedMatrixBlock.getNumColumns();
170176
sparse = false;
171177
nonZeros = uncompressedMatrixBlock.getNonZeros();
172-
decompressedVersion = new SoftReference<>(uncompressedMatrixBlock);
178+
if(!(uncompressedMatrixBlock instanceof CompressedMatrixBlock)) {
179+
decompressedVersion = new SoftReference<>(uncompressedMatrixBlock);
180+
}
173181
}
174182

175183
/**
@@ -189,6 +197,7 @@ public CompressedMatrixBlock(int rl, int cl, long nnz, boolean overlapping, List
189197
this.nonZeros = nnz;
190198
this.overlappingColGroups = overlapping;
191199
this._colGroups = groups;
200+
getInMemorySize(); // cache memory size
192201
}
193202

194203
@Override
@@ -204,6 +213,7 @@ public void reset(int rl, int cl, boolean sp, long estnnz, double val) {
204213
* @param cg The column group to use after.
205214
*/
206215
public void allocateColGroup(AColGroup cg) {
216+
cachedMemorySize = -1;
207217
_colGroups = new ArrayList<>(1);
208218
_colGroups.add(cg);
209219
}
@@ -270,6 +280,12 @@ public synchronized MatrixBlock decompress(int k) {
270280

271281
ret = CLALibDecompress.decompress(this, k);
272282

283+
if(ret.getNonZeros() <= 0) {
284+
LOG.warn("Decompress incorrectly set nnz to 0 or -1");
285+
ret.recomputeNonZeros(k);
286+
}
287+
ret.examSparsity(k);
288+
273289
// Set soft reference to the decompressed version
274290
decompressedVersion = new SoftReference<>(ret);
275291

@@ -290,7 +306,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
290306
* @return The cached decompressed matrix, if it does not exist return null
291307
*/
292308
public MatrixBlock getCachedDecompressed() {
293-
if(decompressedVersion != null) {
309+
if( allowCachingUncompressed && decompressedVersion != null) {
294310
final MatrixBlock mb = decompressedVersion.get();
295311
if(mb != null) {
296312
DMLCompressionStatistics.addDecompressCacheCount();
@@ -302,6 +318,7 @@ public MatrixBlock getCachedDecompressed() {
302318
}
303319

304320
public CompressedMatrixBlock squash(int k) {
321+
cachedMemorySize = -1;
305322
return CLALibSquash.squash(this, k);
306323
}
307324

@@ -377,12 +394,27 @@ public long estimateSizeInMemory() {
377394
* @return an upper bound on the memory used to store this compressed block considering class overhead.
378395
*/
379396
public long estimateCompressedSizeInMemory() {
380-
long total = baseSizeInMemory();
381397

382-
for(AColGroup grp : _colGroups)
383-
total += grp.estimateInMemorySize();
398+
if(cachedMemorySize <= -1L) {
399+
400+
long total = baseSizeInMemory();
401+
// take into consideration duplicate dictionaries
402+
Set<IDictionary> dicts = new HashSet<>();
403+
for(AColGroup grp : _colGroups){
404+
if(grp instanceof ADictBasedColGroup){
405+
IDictionary dg = ((ADictBasedColGroup) grp).getDictionary();
406+
if(dicts.contains(dg))
407+
total -= dg.getInMemorySize();
408+
dicts.add(dg);
409+
}
410+
total += grp.estimateInMemorySize();
411+
}
412+
cachedMemorySize = total;
413+
return total;
384414

385-
return total;
415+
}
416+
else
417+
return cachedMemorySize;
386418
}
387419

388420
public static long baseSizeInMemory() {
@@ -392,6 +424,7 @@ public static long baseSizeInMemory() {
392424
total += 8; // Col Group Ref
393425
total += 8; // v reference
394426
total += 8; // soft reference to decompressed version
427+
total += 8; // long cached memory size
395428
total += 1 + 7; // Booleans plus padding
396429

397430
total += 40; // Col Group Array List
@@ -431,6 +464,7 @@ public long estimateSizeOnDisk() {
431464

432465
@Override
433466
public void readFields(DataInput in) throws IOException {
467+
cachedMemorySize = -1;
434468
// deserialize compressed block
435469
rlen = in.readInt();
436470
clen = in.readInt();
@@ -736,8 +770,22 @@ public MatrixBlock rexpandOperations(MatrixBlock ret, double max, boolean rows,
736770

737771
@Override
738772
public boolean isEmptyBlock(boolean safe) {
739-
final long nonZeros = getNonZeros();
740-
return _colGroups == null || nonZeros == 0 || (nonZeros == -1 && recomputeNonZeros() == 0);
773+
if(nonZeros > 1)
774+
return false;
775+
else if(_colGroups == null || nonZeros == 0)
776+
return true;
777+
else{
778+
if(nonZeros == -1){
779+
// try to use column groups
780+
for(AColGroup g : _colGroups)
781+
if(!g.isEmpty())
782+
return false;
783+
// Otherwise recompute non zeros.
784+
recomputeNonZeros();
785+
}
786+
787+
return getNonZeros() == 0;
788+
}
741789
}
742790

743791
@Override
@@ -1045,6 +1093,7 @@ public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, boolean awareD
10451093
}
10461094

10471095
private void copyCompressedMatrix(CompressedMatrixBlock that) {
1096+
cachedMemorySize = -1;
10481097
this.rlen = that.getNumRows();
10491098
this.clen = that.getNumColumns();
10501099
this.sparseBlock = null;
@@ -1059,7 +1108,7 @@ private void copyCompressedMatrix(CompressedMatrixBlock that) {
10591108
}
10601109

10611110
public SoftReference<MatrixBlock> getSoftReferenceToDecompressed() {
1062-
return decompressedVersion;
1111+
return allowCachingUncompressed ? decompressedVersion : null;
10631112
}
10641113

10651114
public void clearSoftReferenceToDecompressed() {

src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.List;
2727
import java.util.Set;
2828

29+
import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary;
2930
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
3031
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
3132
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
@@ -63,8 +64,8 @@ public IDictionary getDictionary() {
6364

6465
@Override
6566
public final void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) {
66-
if(_dict instanceof IdentityDictionary) {
67-
final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
67+
if(_dict instanceof AIdentityDictionary) {
68+
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
6869
final MatrixBlock mb = md.getMatrixBlock();
6970
// The dictionary is never empty.
7071
if(mb.isInSparseFormat())
@@ -87,8 +88,8 @@ else if(_dict instanceof MatrixBlockDictionary) {
8788

8889
@Override
8990
public void decompressToSparseBlockTransposed(SparseBlockMCSR sb, int nColOut) {
90-
if(_dict instanceof IdentityDictionary) {
91-
final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
91+
if(_dict instanceof AIdentityDictionary) {
92+
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
9293
final MatrixBlock mb = md.getMatrixBlock();
9394
// The dictionary is never empty.
9495
if(mb.isInSparseFormat())
@@ -123,8 +124,8 @@ protected abstract void decompressToSparseBlockTransposedDenseDictionary(SparseB
123124

124125
@Override
125126
public final void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
126-
if(_dict instanceof IdentityDictionary) {
127-
final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
127+
if(_dict instanceof AIdentityDictionary) {
128+
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
128129
final MatrixBlock mb = md.getMatrixBlock();
129130
// The dictionary is never empty.
130131
if(mb.isInSparseFormat())
@@ -147,9 +148,8 @@ else if(_dict instanceof MatrixBlockDictionary) {
147148

148149
@Override
149150
public final void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC) {
150-
if(_dict instanceof IdentityDictionary) {
151-
152-
final MatrixBlockDictionary md = ((IdentityDictionary) _dict).getMBDict();
151+
if(_dict instanceof AIdentityDictionary) {
152+
final MatrixBlockDictionary md = ((AIdentityDictionary) _dict).getMBDict();
153153
final MatrixBlock mb = md.getMatrixBlock();
154154
// The dictionary is never empty.
155155
if(mb.isInSparseFormat())
@@ -249,8 +249,8 @@ public final AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols, i
249249
return null;
250250

251251
// is candidate for identity mm.
252-
if(_dict instanceof IdentityDictionary //
253-
&& !((IdentityDictionary) _dict).withEmpty()
252+
if(_dict instanceof AIdentityDictionary //
253+
&& !((AIdentityDictionary) _dict).withEmpty()
254254
&& right.getNumRows() == _colIndexes.size() //
255255
&& allowShallowIdentityRightMult()){
256256

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
import org.apache.commons.lang3.NotImplementedException;
2727
import org.apache.sysds.runtime.compress.DMLCompressionException;
2828
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
29+
import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary;
2930
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
3031
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
3132
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
32-
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
3333
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
3434
import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict;
3535
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
@@ -327,8 +327,8 @@ public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSa
327327
* @param constV The output columns.
328328
*/
329329
public final void addToCommon(double[] constV) {
330-
if(_dict instanceof IdentityDictionary) {
331-
MatrixBlock mb = ((IdentityDictionary) _dict).getMBDict().getMatrixBlock();
330+
if(_dict instanceof AIdentityDictionary) {
331+
MatrixBlock mb = ((AIdentityDictionary) _dict).getMBDict().getMatrixBlock();
332332
if(mb.isInSparseFormat())
333333
addToCommonSparse(constV, mb.getSparseBlock());
334334
else
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.colgroup.dictionary;
21+
22+
import java.lang.ref.SoftReference;
23+
24+
public abstract class ACachingMBDictionary extends ADictionary {
25+
26+
/** A softly referenced cache of the materialized MatrixBlockDictionary created by createMBDict. */
27+
protected volatile SoftReference<MatrixBlockDictionary> cache = null;
28+
29+
@Override
30+
public final MatrixBlockDictionary getMBDict(int nCol) {
31+
if(cache != null) {
32+
MatrixBlockDictionary r = cache.get();
33+
if(r != null)
34+
return r;
35+
}
36+
MatrixBlockDictionary ret = createMBDict(nCol);
37+
cache = new SoftReference<>(ret);
38+
return ret;
39+
}
40+
41+
public abstract MatrixBlockDictionary createMBDict(int nCol);
42+
}

0 commit comments

Comments
 (0)