Skip to content

Commit 53f72ed

Browse files
committed
[MINOR] Update decompression for zeros
This commit adds a check when decompressing overlapping matrices that rounds values within a very small epsilon of zero to exactly zero. Previously, these rounding errors could turn a sparse matrix dense when it was decompressed from the overlapping compressed state. Closes #2170
1 parent c42a629 commit 53f72ed

File tree

2 files changed

+183
-71
lines changed

2 files changed

+183
-71
lines changed

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java

Lines changed: 140 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
import java.util.concurrent.ExecutorService;
2727
import java.util.concurrent.Future;
2828

29+
import org.apache.commons.lang3.NotImplementedException;
2930
import org.apache.commons.logging.Log;
3031
import org.apache.commons.logging.LogFactory;
3132
import org.apache.sysds.api.DMLScript;
33+
import org.apache.sysds.runtime.DMLRuntimeException;
3234
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
3335
import org.apache.sysds.runtime.compress.DMLCompressionException;
3436
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
@@ -65,6 +67,11 @@ public static MatrixBlock decompress(CompressedMatrixBlock cmb, int k) {
6567

6668
public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, int colOffset, int k,
6769
boolean countNNz) {
70+
decompressTo(cmb, ret, rowOffset, colOffset, k, countNNz, false);
71+
}
72+
73+
public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, int colOffset, int k,
74+
boolean countNNz, boolean reset) {
6875
Timing time = new Timing(true);
6976
if(cmb.getNumColumns() + colOffset > ret.getNumColumns() || cmb.getNumRows() + rowOffset > ret.getNumRows()) {
7077
LOG.warn(
@@ -78,12 +85,12 @@ public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int
7885

7986
final boolean outSparse = ret.isInSparseFormat();
8087
if(!cmb.isEmpty()) {
81-
if(outSparse && cmb.isOverlapping())
88+
if(outSparse && (cmb.isOverlapping() || reset))
8289
throw new DMLCompressionException("Not supported decompression into sparse block from overlapping state");
8390
else if(outSparse)
8491
decompressToSparseBlock(cmb, ret, rowOffset, colOffset);
8592
else
86-
decompressToDenseBlock(cmb, ret.getDenseBlock(), rowOffset, colOffset);
93+
decompressToDenseBlock(cmb, ret.getDenseBlock(), rowOffset, colOffset, k, reset);
8794
}
8895

8996
if(DMLScript.STATISTICS) {
@@ -94,7 +101,7 @@ else if(outSparse)
94101
}
95102

96103
if(countNNz)
97-
ret.recomputeNonZeros();
104+
ret.recomputeNonZeros(k);
98105
}
99106

100107
private static void decompressToSparseBlock(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset,
@@ -115,23 +122,67 @@ private static void decompressToSparseBlock(CompressedMatrixBlock cmb, MatrixBlo
115122
ret.checkSparseRows();
116123
}
117124

118-
private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock ret, int rowOffset, int colOffset) {
119-
final List<AColGroup> groups = cmb.getColGroups();
125+
private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock ret, int rowOffset, int colOffset,
126+
int k, boolean reset) {
127+
List<AColGroup> groups = cmb.getColGroups();
120128
// final int nCols = cmb.getNumColumns();
121129
final int nRows = cmb.getNumRows();
122130

123131
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
124-
if(shouldFilter) {
132+
if(shouldFilter && !CLALibUtils.alreadyPreFiltered(groups, cmb.getNumColumns())) {
125133
final double[] constV = new double[cmb.getNumColumns()];
126-
final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);
127-
for(AColGroup g : filteredGroups)
128-
g.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset);
134+
groups = CLALibUtils.filterGroups(groups, constV);
129135
AColGroup cRet = ColGroupConst.create(constV);
130-
cRet.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset);
136+
groups.add(cRet);
131137
}
132-
else {
133-
for(AColGroup g : groups)
134-
g.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset);
138+
139+
if(k > 1 && nRows > 1000)
140+
decompressToDenseBlockParallel(ret, groups, rowOffset, colOffset, nRows, k, reset);
141+
else
142+
decompressToDenseBlockSingleThread(ret, groups, rowOffset, colOffset, nRows, reset);
143+
}
144+
145+
private static void decompressToDenseBlockSingleThread(DenseBlock ret, List<AColGroup> groups, int rowOffset,
146+
int colOffset, int nRows, boolean reset) {
147+
decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, 0, nRows, reset);
148+
}
149+
150+
private static void decompressToDenseBlockBlock(DenseBlock ret, List<AColGroup> groups, int rowOffset, int colOffset,
151+
int rl, int ru, boolean reset) {
152+
if(reset) {
153+
if(ret.isContiguous()) {
154+
final int nCol = ret.getDim(1);
155+
ret.fillBlock(0, rl * nCol, ru * nCol, 0.0);
156+
}
157+
else
158+
throw new NotImplementedException();
159+
}
160+
for(AColGroup g : groups)
161+
g.decompressToDenseBlock(ret, rl, ru, rowOffset, colOffset);
162+
}
163+
164+
private static void decompressToDenseBlockParallel(DenseBlock ret, List<AColGroup> groups, int rowOffset,
165+
int colOffset, int nRows, int k, boolean reset) {
166+
167+
final int blklen = Math.max(nRows / k, 512);
168+
final ExecutorService pool = CommonThreadPool.get(k);
169+
try {
170+
List<Future<?>> tasks = new ArrayList<>(nRows / blklen);
171+
for(int r = 0; r < nRows; r += blklen) {
172+
final int start = r;
173+
final int end = Math.min(nRows, r + blklen);
174+
tasks.add(
175+
pool.submit(() -> decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, start, end, reset)));
176+
}
177+
178+
for(Future<?> t : tasks)
179+
t.get();
180+
}
181+
catch(Exception e) {
182+
throw new DMLCompressionException("Failed parallel decompress to");
183+
}
184+
finally {
185+
pool.shutdown();
135186
}
136187
}
137188

@@ -148,7 +199,7 @@ private static MatrixBlock decompressExecute(CompressedMatrixBlock cmb, int k) {
148199
MatrixBlock ret = getUncompressedColGroupAndRemoveFromListOfColGroups(groups, overlapping, nRows, nCols);
149200

150201
if(ret != null && groups.size() == 0) {
151-
ret.setNonZeros(ret.recomputeNonZeros());
202+
ret.setNonZeros(ret.recomputeNonZeros(k));
152203
return ret; // if uncompressedColGroup is only colGroup.
153204
}
154205

@@ -182,23 +233,18 @@ private static MatrixBlock decompressExecute(CompressedMatrixBlock cmb, int k) {
182233
constV = null;
183234

184235
final double eps = getEps(constV);
185-
186236
if(k == 1) {
187-
if(ret.isInSparseFormat()) {
237+
if(ret.isInSparseFormat())
188238
decompressSparseSingleThread(ret, filteredGroups, nRows, blklen);
189-
}
190-
else {
239+
else
191240
decompressDenseSingleThread(ret, filteredGroups, nRows, blklen, constV, eps, nonZeros, overlapping);
192-
}
193241
}
194-
else if(ret.isInSparseFormat()) {
242+
else if(ret.isInSparseFormat())
195243
decompressSparseMultiThread(ret, filteredGroups, nRows, blklen, k);
196-
}
197-
else {
244+
else
198245
decompressDenseMultiThread(ret, filteredGroups, nRows, blklen, constV, eps, k, overlapping);
199-
}
200246

201-
ret.recomputeNonZeros();
247+
ret.recomputeNonZeros(k);
202248
ret.examSparsity();
203249

204250
return ret;
@@ -249,29 +295,40 @@ private static void decompressSparseSingleThread(MatrixBlock ret, List<AColGroup
249295

250296
private static void decompressDenseSingleThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen,
251297
int blklen, double[] constV, double eps, long nonZeros, boolean overlapping) {
298+
299+
final DenseBlock db = ret.getDenseBlock();
300+
final int nCol = ret.getNumColumns();
252301
for(int i = 0; i < rlen; i += blklen) {
253302
final int rl = i;
254303
final int ru = Math.min(i + blklen, rlen);
255304
for(AColGroup grp : filteredGroups)
256-
grp.decompressToDenseBlock(ret.getDenseBlock(), rl, ru);
305+
grp.decompressToDenseBlock(db, rl, ru);
257306
if(constV != null && !ret.isInSparseFormat())
258-
addVector(ret, constV, eps, rl, ru);
307+
addVector(db, nCol, constV, eps, rl, ru);
259308
}
260309
}
261310

262-
protected static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV, int k,
263-
boolean overlapping) {
264-
final int nRows = ret.getNumRows();
265-
final double eps = getEps(constV);
266-
final int blklen = Math.max(nRows / k, 512);
267-
decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping);
268-
}
311+
// private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV, int k,
312+
// boolean overlapping) {
313+
// final int nRows = ret.getNumRows();
314+
// final double eps = getEps(constV);
315+
// final int blklen = Math.max(nRows / k, 512);
316+
// decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping);
317+
// }
269318

270319
protected static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV,
271320
double eps, int k, boolean overlapping) {
321+
322+
Timing time = new Timing(true);
272323
final int nRows = ret.getNumRows();
273324
final int blklen = Math.max(nRows / k, 512);
274325
decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping);
326+
if(DMLScript.STATISTICS) {
327+
final double t = time.stop();
328+
DMLCompressionStatistics.addDecompressTime(t, k);
329+
if(LOG.isTraceEnabled())
330+
LOG.trace("decompressed block w/ k=" + k + " in " + t + "ms.");
331+
}
275332
}
276333

277334
private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen, int blklen,
@@ -297,7 +354,7 @@ private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup>
297354
catch(InterruptedException | ExecutionException ex) {
298355
throw new DMLCompressionException("Parallel decompression failed", ex);
299356
}
300-
finally{
357+
finally {
301358
pool.shutdown();
302359
}
303360
}
@@ -310,13 +367,14 @@ private static void decompressSparseMultiThread(MatrixBlock ret, List<AColGroup>
310367
for(int i = 0; i < rlen; i += blklen)
311368
tasks.add(new DecompressSparseTask(filteredGroups, ret, i, Math.min(i + blklen, rlen)));
312369

370+
LOG.error("tasks:" + tasks);
313371
for(Future<Object> rt : pool.invokeAll(tasks))
314372
rt.get();
315373
}
316374
catch(InterruptedException | ExecutionException ex) {
317375
throw new DMLCompressionException("Parallel decompression failed", ex);
318376
}
319-
finally{
377+
finally {
320378
pool.shutdown();
321379
}
322380
}
@@ -360,22 +418,23 @@ protected DecompressDenseTask(List<AColGroup> colGroups, MatrixBlock ret, double
360418
_eps = eps;
361419
_rl = rl;
362420
_ru = ru;
363-
_blklen = 32768 / ret.getNumColumns();
421+
_blklen = Math.max(32768 / ret.getNumColumns(), 128);
364422
_constV = constV;
365423
}
366424

367425
@Override
368426
public Long call() {
369427
try {
370-
428+
final DenseBlock db = _ret.getDenseBlock();
429+
final int nCol = _ret.getNumColumns();
371430
long nnz = 0;
372431
for(int b = _rl; b < _ru; b += _blklen) {
373432
final int e = Math.min(b + _blklen, _ru);
374433
for(AColGroup grp : _colGroups)
375-
grp.decompressToDenseBlock(_ret.getDenseBlock(), b, e);
434+
grp.decompressToDenseBlock(db, b, e);
376435

377436
if(_constV != null)
378-
addVector(_ret, _constV, _eps, b, e);
437+
addVector(db, nCol, _constV, _eps, b, e);
379438
nnz += _ret.recomputeNonZeros(b, e - 1);
380439
}
381440

@@ -404,23 +463,22 @@ protected DecompressDenseSingleColTask(AColGroup grp, MatrixBlock ret, double ep
404463
_eps = eps;
405464
_rl = rl;
406465
_ru = ru;
407-
_blklen = 32768 / ret.getNumColumns();
466+
_blklen = Math.max(32768 / ret.getNumColumns(), 128);
408467
_constV = constV;
409468
}
410469

411470
@Override
412471
public Long call() {
413472
try {
414-
473+
final DenseBlock db = _ret.getDenseBlock();
474+
final int nCol = _ret.getNumColumns();
415475
long nnz = 0;
416476
for(int b = _rl; b < _ru; b += _blklen) {
417477
final int e = Math.min(b + _blklen, _ru);
418-
// for(AColGroup grp : _colGroups)
419-
_grp.decompressToDenseBlock(_ret.getDenseBlock(), b, e);
478+
_grp.decompressToDenseBlock(db, b, e);
420479

421480
if(_constV != null)
422-
addVector(_ret, _constV, _eps, b, e);
423-
// nnz += _ret.recomputeNonZeros(b, e - 1);
481+
addVector(db, nCol, _constV, _eps, b, e);
424482
}
425483

426484
return nnz;
@@ -446,14 +504,21 @@ protected DecompressSparseTask(List<AColGroup> colGroups, MatrixBlock ret, int r
446504
}
447505

448506
@Override
449-
public Object call() {
450-
final SparseBlock sb = _ret.getSparseBlock();
451-
for(AColGroup grp : _colGroups)
452-
grp.decompressToSparseBlock(_ret.getSparseBlock(), _rl, _ru);
453-
for(int i = _rl; i < _ru; i++)
454-
if(!sb.isEmpty(i))
455-
sb.sort(i);
456-
return null;
507+
public Object call() throws Exception{
508+
try{
509+
510+
final SparseBlock sb = _ret.getSparseBlock();
511+
for(AColGroup grp : _colGroups)
512+
grp.decompressToSparseBlock(_ret.getSparseBlock(), _rl, _ru);
513+
for(int i = _rl; i < _ru; i++)
514+
if(!sb.isEmpty(i))
515+
sb.sort(i);
516+
return null;
517+
}
518+
catch(Exception e){
519+
e.printStackTrace();
520+
throw new DMLRuntimeException(e);
521+
}
457522
}
458523
}
459524

@@ -467,28 +532,32 @@ public Object call() {
467532
* @param rl The row to start at
468533
* @param ru The row to end at
469534
*/
470-
private static void addVector(final MatrixBlock ret, final double[] rowV, final double eps, final int rl,
471-
final int ru) {
472-
final int nCols = ret.getNumColumns();
473-
final DenseBlock db = ret.getDenseBlock();
535+
private static final void addVector(final DenseBlock db, final int nCols, final double[] rowV, final double eps,
536+
final int rl, final int ru) {
537+
if(eps == 0)
538+
addVectorEps(db, nCols, rowV, eps, rl, ru);
539+
else
540+
addVectorNoEps(db, nCols, rowV, eps, rl, ru);
541+
}
474542

475-
if(nCols == 1) {
476-
if(eps == 0)
477-
addValue(db.values(0), rowV[0], rl, ru);
478-
else
479-
addValueEps(db.values(0), rowV[0], eps, rl, ru);
480-
}
481-
else if(db.isContiguous()) {
482-
if(eps == 0)
483-
addVectorContiguousNoEps(db.values(0), rowV, nCols, rl, ru);
484-
else
485-
addVectorContiguousEps(db.values(0), rowV, nCols, eps, rl, ru);
486-
}
487-
else if(eps == 0)
543+
private static final void addVectorEps(final DenseBlock db, final int nCols, final double[] rowV, final double eps,
544+
final int rl, final int ru) {
545+
if(nCols == 1)
546+
addValue(db.values(0), rowV[0], rl, ru);
547+
else if(db.isContiguous())
548+
addVectorContiguousNoEps(db.values(0), rowV, nCols, rl, ru);
549+
else
488550
addVectorNoEps(db, rowV, nCols, rl, ru);
551+
}
552+
553+
private static final void addVectorNoEps(final DenseBlock db, final int nCols, final double[] rowV, final double eps,
554+
final int rl, final int ru) {
555+
if(nCols == 1)
556+
addValueEps(db.values(0), rowV[0], eps, rl, ru);
557+
else if(db.isContiguous())
558+
addVectorContiguousEps(db.values(0), rowV, nCols, eps, rl, ru);
489559
else
490560
addVectorEps(db, rowV, nCols, eps, rl, ru);
491-
492561
}
493562

494563
private static void addValue(final double[] retV, final double v, final int rl, final int ru) {

0 commit comments

Comments
 (0)