Skip to content

Commit 749ec56

Browse files
Baunsgaarde-strauss
authored andcommitted
[SYSTEMDS-3824] Decompressing Transpose
Sebastian Baunsgaard <[email protected]> introduced a new CLALib for Reorg, specifically Transpose e-strauss <[email protected]> applied minor changes: - a manual rewrite in bultin kmeans script to use argmin (reduced runtime by 18%) - added new decompressing transpose to DenseBlock from SparseBlock for ColGroupDDC - fixed bug in sparsity evaluation in decompressed transposed (switch nrow w/ ncol) - minor bug fix in regarding the cached decompression count
1 parent e022eaf commit 749ec56

File tree

6 files changed

+189
-22
lines changed

6 files changed

+189
-22
lines changed

scripts/builtin/kmeans.dml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,11 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
145145
}
146146

147147
# Find the closest centroid for each record
148-
P = D <= minD;
148+
# P = D <= minD;
149149
# If some records belong to multiple centroids, share them equally
150-
P = P / rowSums (P);
150+
# P = P / rowSums (P);
151+
P = table(seq(1,nrow(D)), rowIndexMin(D))
152+
# P = table(seq(1,nrow(D)),compress(rowIndexMin(D)))
151153
# Compute the column normalization factor for P
152154
P_denom = colSums (P);
153155
# Compute new centroids as weighted averages over the records

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
6060
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
6161
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
62+
import org.apache.sysds.runtime.compress.lib.CLALibReorg;
6263
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
6364
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
6465
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -633,21 +634,7 @@ public MatrixBlock replaceOperations(MatrixValue result, double pattern, double
633634

634635
@Override
635636
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) {
636-
if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) {
637-
MatrixBlock tmp = decompress(op.getNumThreads());
638-
long nz = tmp.setNonZeros(tmp.getNonZeros());
639-
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
640-
tmp.setNonZeros(nz);
641-
return tmp;
642-
}
643-
else {
644-
// Allow transpose to be compressed output. In general we need to have a transposed flag on
645-
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
646-
String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName();
647-
MatrixBlock tmp = getUncompressed(message, op.getNumThreads());
648-
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
649-
}
650-
637+
return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length);
651638
}
652639

653640
public boolean isOverlapping() {
@@ -1311,7 +1298,7 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
13111298

13121299
@Override
13131300
public MatrixBlock transpose(int k) {
1314-
return getUncompressed().transpose(k);
1301+
return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0);
13151302
}
13161303

13171304
@Override

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,21 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i
251251

252252
@Override
253253
protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) {
254-
throw new NotImplementedException();
254+
for(int i = rl; i < ru; i++) {
255+
final int vr = _data.getIndex(i);
256+
if(sb.isEmpty(vr))
257+
continue;
258+
final int apos = sb.pos(vr);
259+
final int alen = sb.size(vr) + apos;
260+
final int[] aix = sb.indexes(vr);
261+
final double[] aval = sb.values(vr);
262+
for(int j = apos; j < alen; j++) {
263+
final int rowOut = _colIndexes.get(aix[j]);
264+
final double[] c = db.values(rowOut);
265+
final int off = db.pos(rowOut);
266+
c[off + i] += aval[j];
267+
}
268+
}
255269
}
256270

257271
@Override
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.lib;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.concurrent.ExecutorService;
25+
import java.util.concurrent.Future;
26+
27+
import org.apache.commons.logging.Log;
28+
import org.apache.commons.logging.LogFactory;
29+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
30+
import org.apache.sysds.runtime.compress.DMLCompressionException;
31+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
32+
import org.apache.sysds.runtime.data.DenseBlock;
33+
import org.apache.sysds.runtime.data.SparseBlock;
34+
import org.apache.sysds.runtime.data.SparseBlockMCSR;
35+
import org.apache.sysds.runtime.functionobjects.SwapIndex;
36+
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
37+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
38+
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
39+
import org.apache.sysds.runtime.util.CommonThreadPool;
40+
41+
public class CLALibReorg {
42+
43+
protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName());
44+
45+
public static boolean warned = false;
46+
47+
public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow,
48+
int startColumn, int length) {
49+
// SwapIndex is transpose
50+
if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) {
51+
MatrixBlock tmp = cmb.decompress(op.getNumThreads());
52+
long nz = tmp.setNonZeros(tmp.getNonZeros());
53+
if(tmp.isInSparseFormat())
54+
return LibMatrixReorg.transpose(tmp); // edge case...
55+
else
56+
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
57+
tmp.setNonZeros(nz);
58+
return tmp;
59+
}
60+
else if(op.fn instanceof SwapIndex) {
61+
MatrixBlock tmp = cmb.getCachedDecompressed();
62+
if(tmp != null)
63+
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
64+
// Allow transpose to be compressed output. In general we need to have a transposed flag on
65+
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
66+
return transpose(cmb, ret, op.getNumThreads());
67+
}
68+
else {
69+
String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null;
70+
MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads());
71+
warned = true;
72+
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
73+
}
74+
}
75+
76+
private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
77+
78+
final long nnz = cmb.getNonZeros();
79+
final int nRow = cmb.getNumRows();
80+
final int nCol = cmb.getNumColumns();
81+
final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol,nRow, nnz);
82+
if(sparseOut)
83+
return transposeSparse(cmb, ret, k, nRow, nCol, nnz);
84+
else
85+
return transposeDense(cmb, ret, k, nRow, nCol, nnz);
86+
}
87+
88+
private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
89+
long nnz) {
90+
if(ret == null)
91+
ret = new MatrixBlock(nCol, nRow, true, nnz);
92+
else
93+
ret.reset(nCol, nRow, true, nnz);
94+
95+
ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR);
96+
97+
final int nColOut = ret.getNumColumns();
98+
99+
if(k > 1 && cmb.getColGroups().size() > 1)
100+
decompressToTransposedSparseParallel((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut, k);
101+
else
102+
decompressToTransposedSparseSingleThread((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut);
103+
104+
return ret;
105+
}
106+
107+
private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
108+
long nnz) {
109+
if(ret == null)
110+
ret = new MatrixBlock(nCol, nRow, false, nnz);
111+
else
112+
ret.reset(nCol, nRow, false, nnz);
113+
114+
// TODO: parallelize
115+
ret.allocateDenseBlock();
116+
117+
decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow);
118+
return ret;
119+
}
120+
121+
private static void decompressToTransposedDense(DenseBlock ret, List<AColGroup> groups, int rlen, int rl, int ru) {
122+
for(int i = 0; i < groups.size(); i++) {
123+
AColGroup g = groups.get(i);
124+
g.decompressToDenseBlockTransposed(ret, rl, ru);
125+
}
126+
}
127+
128+
private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List<AColGroup> groups,
129+
int nColOut) {
130+
for(int i = 0; i < groups.size(); i++) {
131+
AColGroup g = groups.get(i);
132+
g.decompressToSparseBlockTransposed(ret, nColOut);
133+
}
134+
}
135+
136+
private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut,
137+
int k) {
138+
final ExecutorService pool = CommonThreadPool.get(k);
139+
try {
140+
final List<Future<?>> tasks = new ArrayList<>(groups.size());
141+
142+
for(int i = 0; i < groups.size(); i++) {
143+
final AColGroup g = groups.get(i);
144+
tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut)));
145+
}
146+
147+
for(Future<?> f : tasks)
148+
f.get();
149+
150+
}
151+
catch(Exception e) {
152+
throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
153+
}
154+
finally {
155+
pool.shutdown();
156+
}
157+
}
158+
}

src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ private void processSimpleCompressInstruction(ExecutionContext ec) {
138138
else if(ec.isMatrixObject(input1.getName()))
139139
processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), _numThreads, root);
140140
else {
141-
throw new NotImplementedException("Not supported other types of input for compression than frame and matrix");
141+
LOG.warn("Compression on Scalar should not happen");
142+
ScalarObject Scalar = ec.getScalarInput(input1);
143+
ec.setScalarOutput(output.getName(),Scalar);
142144
}
143145
}
144146

src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import org.apache.sysds.utils.Statistics;
3333
import org.junit.Assert;
3434

35+
import java.io.ByteArrayOutputStream;
36+
3537
public abstract class CompressBase extends AutomatedTestBase {
3638
// private static final Log LOG = LogFactory.getLog(CompressBase.class.getName());
3739

@@ -66,15 +68,17 @@ public void compressTest(int rows, int cols, double sparsity, ExecType instType,
6668
fullDMLScriptName = SCRIPT_DIR + "/functions/compress/compress_" + name + ".dml";
6769
programArgs = new String[] {"-stats", "100", "-nvargs", "A=" + input("A")};
6870

69-
String out = runTest(null).toString();
71+
ByteArrayOutputStream tmp = runTest(null);
72+
String out = tmp != null ? runTest(null).toString() : "";
7073

7174
int decompressCount = DMLCompressionStatistics.getDecompressionCount();
7275
long compressionCount = (instType == ExecType.SPARK) ? Statistics
7376
.getCPHeavyHitterCount("sp_compress") : Statistics.getCPHeavyHitterCount("compress");
7477
DMLCompressionStatistics.reset();
7578

7679
Assert.assertEquals(out + "\ncompression count wrong : ", compressionCount, compressionCountsExpected);
77-
Assert.assertTrue(out + "\nDecompression count wrong : ",
80+
Assert.assertTrue(out + "\nDecompression count wrong : " + decompressCount +
81+
(decompressionCountExpected >= 0 ? " [expected: " + decompressionCountExpected+ "]" : ""),
7882
decompressionCountExpected >= 0 ? decompressionCountExpected == decompressCount : decompressCount > 1);
7983

8084
}

0 commit comments

Comments
 (0)