Skip to content

Commit c42a629

Browse files
committed
[MINOR] Fused decompression in CLALibScalar
Closes #2169
1 parent e12d9d2 commit c42a629

File tree

1 file changed

+81
-3
lines changed

1 file changed

+81
-3
lines changed

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@
3131
import org.apache.sysds.hops.OptimizerUtils;
3232
import org.apache.sysds.runtime.DMLRuntimeException;
3333
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
34+
import org.apache.sysds.runtime.compress.DMLCompressionException;
3435
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
3536
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
3637
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
3738
import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE;
3839
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
3940
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
41+
import org.apache.sysds.runtime.data.DenseBlock;
4042
import org.apache.sysds.runtime.functionobjects.Divide;
4143
import org.apache.sysds.runtime.functionobjects.Minus;
4244
import org.apache.sysds.runtime.functionobjects.Multiply;
@@ -57,10 +59,13 @@ private CLALibScalar() {
5759
}
5860

5961
public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) {
62+
// Timing time = new Timing(true);
6063
if(isInvalidForCompressedOutput(m1, sop)) {
6164
LOG.warn("scalar overlapping not supported for op: " + sop.fn.getClass().getSimpleName());
62-
MatrixBlock m1d = m1.decompress(sop.getNumThreads());
63-
return m1d.scalarOperations(sop, result);
65+
66+
return fusedScalarAndDecompress(m1, sop);
67+
// MatrixBlock m1d = m1.decompress(sop.getNumThreads());
68+
// return m1d.scalarOperations(sop, result);
6469
}
6570
CompressedMatrixBlock ret = setupRet(m1, result);
6671

@@ -89,11 +94,84 @@ public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixB
8994
ret.setOverlapping(m1.isOverlapping());
9095
}
9196

92-
ret.recomputeNonZeros();
97+
if(sop.fn instanceof Divide) {
98+
ret.setNonZeros(m1.getNonZeros());
99+
}
100+
else {
101+
ret.recomputeNonZeros();
102+
}
93103

104+
// System.out.println("CLA Scalar: " + sop + " " + m1.getNumRows() + ", " + m1.getNumColumns() + ", " +
105+
// m1.getColGroups().size()
106+
// + " -- " + "\t\t" + time.stop());
94107
return ret;
95108
}
96109

110+
private static MatrixBlock fusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) {
111+
int k = sop.getNumThreads();
112+
ExecutorService pool = CommonThreadPool.get(k);
113+
try {
114+
final int nRow = in.getNumRows();
115+
final int nCol = in.getNumColumns();
116+
final MatrixBlock out = new MatrixBlock(nRow, nCol, false);
117+
final List<AColGroup> groups = in.getColGroups();
118+
out.allocateDenseBlock();
119+
final DenseBlock db = out.getDenseBlock();
120+
final int blkz = Math.max((int)(Math.ceil((double)nRow / k)), 256);
121+
final List<Future<Long>> tasks = new ArrayList<>();
122+
for(int i = 0; i < nRow; i += blkz) {
123+
final int start = i;
124+
final int end = Math.min(i + blkz, nRow);
125+
tasks.add(pool.submit(() -> fusedDecompressAndScalar(groups, nCol, start, end, db, sop)));
126+
}
127+
long nnz = 0;
128+
for(Future<Long> t : tasks) {
129+
nnz += t.get();
130+
}
131+
out.setNonZeros(nnz);
132+
out.examSparsity(true, k);
133+
return out;
134+
}
135+
catch(Exception e) {
136+
throw new DMLCompressionException("failed fused scalar operation", e);
137+
}
138+
finally {
139+
pool.shutdown();
140+
}
141+
142+
// MatrixBlock m1d = m1.decompress(sop.getNumThreads());
143+
// return m1d.scalarOperations(sop, result);
144+
}
145+
146+
private static long fusedDecompressAndScalar(final List<AColGroup> groups, int nCol, int start, int end,
147+
DenseBlock db, ScalarOperator sop) {
148+
long nnz = 0;
149+
for(int b = start; b < end; b += 32) {
150+
int bs = b;
151+
int be = Math.min(b + 32, end);
152+
nnz += fusedDecompressAndScalarBlock(groups, nCol, bs, be, db, sop);
153+
}
154+
return nnz;
155+
}
156+
157+
private static long fusedDecompressAndScalarBlock(final List<AColGroup> groups, int nCol, int bs, int be,
158+
DenseBlock db, ScalarOperator sop) {
159+
long nnz = 0;
160+
for(AColGroup g : groups) {
161+
// main block to optimize is decompression speed since it is most likely an overlapping input
162+
g.decompressToDenseBlock(db, bs, be);
163+
}
164+
for(int r = bs; r < be; r++) {
165+
double[] vals = db.values(r);
166+
int off = db.pos(r);
167+
for(int c = off; c < nCol + off; c++) {
168+
vals[c] = sop.executeScalar(vals[c]);
169+
nnz += vals[c] == 0 ? 0 : 1;
170+
}
171+
}
172+
return nnz;
173+
}
174+
97175
private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixValue result) {
98176
CompressedMatrixBlock ret;
99177
if(result == null || !(result instanceof CompressedMatrixBlock))

0 commit comments

Comments
 (0)