Skip to content

Commit dc3947a

Browse files
Baunsgaarde-strauss
andcommitted
[SYSTEMDS-3824] Decompressing Transpose
Sebastian Baunsgaard <[email protected]> introduced a new CLALib for Reorg, specifically Transpose e-strauss <[email protected]> applied minor changes: - a manual rewrite in bultin kmeans script to use argmin (reduced runtime by 18%) - added new decompressing transpose to DenseBlock from SparseBlock for ColGroupDDC - fixed bug in sparsity evaluation in decompressed transposed (switch nrow w/ ncol) - minor bug fix in regarding the cached decompression count - fixed ctable with seq fuse rewrite fused ctable with given output dim (disaled: performance decrease, need to fix it first) - fixed null handling in fused seq ctable - fixed tests which passed for the wrong reason Co-authored-by: e-strauss <[email protected]> Co-authored-by: Sebastian Baunsgaard <[email protected]>
1 parent fd1ba7c commit dc3947a

File tree

12 files changed

+268
-87
lines changed

12 files changed

+268
-87
lines changed

scripts/builtin/kmeans.dml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
148148
P = D <= minD;
149149
# If some records belong to multiple centroids, share them equally
150150
P = P / rowSums (P);
151+
# P = table(seq(1,num_records), rowIndexMin(D), num_records, num_centroids)
151152
# Compute the column normalization factor for P
152153
P_denom = colSums (P);
153154
# Compute new centroids as weighted averages over the records

src/main/java/org/apache/sysds/hops/TernaryOp.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,8 @@ public boolean isSequenceRewriteApplicable( boolean left )
651651

652652
try
653653
{
654+
// TODO: to rewrite is not currently not triggered if outdim are given --> getInput().size()>=3
655+
// currently disabled due performance decrease
654656
if( getInput().size()==2 || (getInput().size()==3 && getInput().get(2).getDataType()==DataType.SCALAR) )
655657
{
656658
Hop input1 = getInput().get(0);

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
6060
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
6161
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
62+
import org.apache.sysds.runtime.compress.lib.CLALibReorg;
6263
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
6364
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
6465
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -633,21 +634,7 @@ public MatrixBlock replaceOperations(MatrixValue result, double pattern, double
633634

634635
@Override
635636
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) {
636-
if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) {
637-
MatrixBlock tmp = decompress(op.getNumThreads());
638-
long nz = tmp.setNonZeros(tmp.getNonZeros());
639-
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
640-
tmp.setNonZeros(nz);
641-
return tmp;
642-
}
643-
else {
644-
// Allow transpose to be compressed output. In general we need to have a transposed flag on
645-
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
646-
String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName();
647-
MatrixBlock tmp = getUncompressed(message, op.getNumThreads());
648-
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
649-
}
650-
637+
return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length);
651638
}
652639

653640
public boolean isOverlapping() {
@@ -1311,7 +1298,7 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
13111298

13121299
@Override
13131300
public MatrixBlock transpose(int k) {
1314-
return getUncompressed().transpose(k);
1301+
return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0);
13151302
}
13161303

13171304
@Override

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,21 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i
251251

252252
@Override
253253
protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) {
254-
throw new NotImplementedException();
254+
for(int i = rl; i < ru; i++) {
255+
final int vr = _data.getIndex(i);
256+
if(sb.isEmpty(vr))
257+
continue;
258+
final int apos = sb.pos(vr);
259+
final int alen = sb.size(vr) + apos;
260+
final int[] aix = sb.indexes(vr);
261+
final double[] aval = sb.values(vr);
262+
for(int j = apos; j < alen; j++) {
263+
final int rowOut = _colIndexes.get(aix[j]);
264+
final double[] c = db.values(rowOut);
265+
final int off = db.pos(rowOut);
266+
c[off + i] += aval[j];
267+
}
268+
}
255269
}
256270

257271
@Override
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.lib;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.concurrent.ExecutorService;
25+
import java.util.concurrent.Future;
26+
27+
import org.apache.commons.logging.Log;
28+
import org.apache.commons.logging.LogFactory;
29+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
30+
import org.apache.sysds.runtime.compress.DMLCompressionException;
31+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
32+
import org.apache.sysds.runtime.data.DenseBlock;
33+
import org.apache.sysds.runtime.data.SparseBlock;
34+
import org.apache.sysds.runtime.data.SparseBlockMCSR;
35+
import org.apache.sysds.runtime.functionobjects.SwapIndex;
36+
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
37+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
38+
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
39+
import org.apache.sysds.runtime.util.CommonThreadPool;
40+
41+
public class CLALibReorg {
42+
43+
protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName());
44+
45+
public static boolean warned = false;
46+
47+
public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow,
48+
int startColumn, int length) {
49+
// SwapIndex is transpose
50+
if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) {
51+
MatrixBlock tmp = cmb.decompress(op.getNumThreads());
52+
long nz = tmp.setNonZeros(tmp.getNonZeros());
53+
if(tmp.isInSparseFormat())
54+
return LibMatrixReorg.transpose(tmp); // edge case...
55+
else
56+
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
57+
tmp.setNonZeros(nz);
58+
return tmp;
59+
}
60+
else if(op.fn instanceof SwapIndex) {
61+
MatrixBlock tmp = cmb.getCachedDecompressed();
62+
if(tmp != null)
63+
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
64+
// Allow transpose to be compressed output. In general we need to have a transposed flag on
65+
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
66+
return transpose(cmb, ret, op.getNumThreads());
67+
}
68+
else {
69+
String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null;
70+
MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads());
71+
warned = true;
72+
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
73+
}
74+
}
75+
76+
private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
77+
78+
final long nnz = cmb.getNonZeros();
79+
final int nRow = cmb.getNumRows();
80+
final int nCol = cmb.getNumColumns();
81+
final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol,nRow, nnz);
82+
if(sparseOut)
83+
return transposeSparse(cmb, ret, k, nRow, nCol, nnz);
84+
else
85+
return transposeDense(cmb, ret, k, nRow, nCol, nnz);
86+
}
87+
88+
private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
89+
long nnz) {
90+
if(ret == null)
91+
ret = new MatrixBlock(nCol, nRow, true, nnz);
92+
else
93+
ret.reset(nCol, nRow, true, nnz);
94+
95+
ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR);
96+
97+
final int nColOut = ret.getNumColumns();
98+
99+
if(k > 1 && cmb.getColGroups().size() > 1)
100+
decompressToTransposedSparseParallel((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut, k);
101+
else
102+
decompressToTransposedSparseSingleThread((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut);
103+
104+
return ret;
105+
}
106+
107+
private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
108+
long nnz) {
109+
if(ret == null)
110+
ret = new MatrixBlock(nCol, nRow, false, nnz);
111+
else
112+
ret.reset(nCol, nRow, false, nnz);
113+
114+
// TODO: parallelize
115+
ret.allocateDenseBlock();
116+
117+
decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow);
118+
return ret;
119+
}
120+
121+
private static void decompressToTransposedDense(DenseBlock ret, List<AColGroup> groups, int rlen, int rl, int ru) {
122+
for(int i = 0; i < groups.size(); i++) {
123+
AColGroup g = groups.get(i);
124+
g.decompressToDenseBlockTransposed(ret, rl, ru);
125+
}
126+
}
127+
128+
private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List<AColGroup> groups,
129+
int nColOut) {
130+
for(int i = 0; i < groups.size(); i++) {
131+
AColGroup g = groups.get(i);
132+
g.decompressToSparseBlockTransposed(ret, nColOut);
133+
}
134+
}
135+
136+
private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut,
137+
int k) {
138+
final ExecutorService pool = CommonThreadPool.get(k);
139+
try {
140+
final List<Future<?>> tasks = new ArrayList<>(groups.size());
141+
142+
for(int i = 0; i < groups.size(); i++) {
143+
final AColGroup g = groups.get(i);
144+
tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut)));
145+
}
146+
147+
for(Future<?> f : tasks)
148+
f.get();
149+
150+
}
151+
catch(Exception e) {
152+
throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
153+
}
154+
finally {
155+
pool.shutdown();
156+
}
157+
}
158+
}

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
4040
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
4141
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
42+
import org.apache.sysds.runtime.matrix.data.Pair;
4243
import org.apache.sysds.runtime.util.CommonThreadPool;
4344
import org.apache.sysds.runtime.util.UtilFunctions;
4445

@@ -71,19 +72,23 @@ public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut, int
7172

7273
try {
7374
final int[] map = new int[seqHeight];
74-
int maxCol = constructInitialMapping(map, A, k);
75+
Pair<Integer, Integer> meta = constructInitialMapping(map, A, k, nColOut);
76+
int maxCol = meta.getKey();
77+
int nZeros = meta.getValue();
7578
boolean containsNull = maxCol < 0;
7679
maxCol = Math.abs(maxCol);
7780

81+
boolean cutOff = false;
7882
if(nColOut == -1)
7983
nColOut = maxCol;
8084
else if(nColOut < maxCol)
81-
throw new DMLRuntimeException("invalid nColOut, requested: " + nColOut + " but have to be : " + maxCol);
85+
cutOff = true;
8286

83-
final int nNulls = containsNull ? correctNulls(map, nColOut) : 0;
87+
if(containsNull)
88+
correctNulls(map, nColOut);
8489
if(nColOut == 0) // edge case of empty zero dimension block.
8590
return new MatrixBlock(seqHeight, 0, 0.0);
86-
return createCompressedReturn(map, nColOut, seqHeight, nNulls, containsNull, k);
91+
return createCompressedReturn(map, nColOut, seqHeight, nZeros, containsNull || cutOff, k);
8792
}
8893
catch(Exception e) {
8994
throw new RuntimeException("Failed table seq operator", e);
@@ -139,7 +144,7 @@ private static int correctNulls(int[] map, int nColOut) {
139144
return nNulls;
140145
}
141146

142-
private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
147+
private static Pair<Integer,Integer> constructInitialMapping(int[] map, MatrixBlock A, int k, int maxOutCol) {
143148
if(A.isEmpty() || A.isInSparseFormat())
144149
throw new DMLRuntimeException("not supported empty or sparse construction of seq table");
145150
final MatrixBlock Ac;
@@ -155,20 +160,23 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
155160
try {
156161

157162
int blkz = Math.max((map.length / k), 1000);
158-
List<Future<Integer>> tasks = new ArrayList<>();
163+
List<Future<Pair<Integer,Integer>>> tasks = new ArrayList<>();
159164
for(int i = 0; i < map.length; i += blkz) {
160165
final int start = i;
161166
final int end = Math.min(i + blkz, map.length);
162-
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end)));
167+
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end, maxOutCol)));
163168
}
164169

165170
int maxCol = 0;
166-
for(Future<Integer> f : tasks) {
167-
int tmp = f.get();
168-
if(Math.abs(tmp) > Math.abs(maxCol))
169-
maxCol = tmp;
171+
int zeros = 0;
172+
for(Future<Pair<Integer,Integer>> f : tasks) {
173+
int tmpMaxCol = f.get().getKey();
174+
int tmpZeros = f.get().getValue();
175+
if(Math.abs(tmpMaxCol) > Math.abs(maxCol))
176+
maxCol = tmpMaxCol;
177+
zeros += tmpZeros;
170178
}
171-
return maxCol;
179+
return new Pair<Integer,Integer>(maxCol, zeros);
172180
}
173181
catch(Exception e) {
174182
throw new DMLRuntimeException(e);
@@ -179,33 +187,32 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
179187

180188
}
181189

182-
private static int partialMapping(int[] map, MatrixBlock A, int start, int end) {
190+
private static Pair<Integer, Integer> partialMapping(int[] map, MatrixBlock A, int start, int end, int maxOutCol) {
183191

184192
int maxCol = 0;
185-
boolean containsNull = false;
186-
193+
int zeros = 0;
187194
final double[] aVals = A.getDenseBlockValues();
188195

189196
for(int i = start; i < end; i++) {
190197
final double v2 = aVals[i];
191-
if(Double.isNaN(v2)) {
192-
map[i] = -1; // assign temporarily to -1
193-
containsNull = true;
194-
}
195-
else {
196-
// safe casts to long for consistent behavior with indexing
197-
int col = UtilFunctions.toInt(v2);
198-
if(col <= 0)
199-
throw new DMLRuntimeException(
198+
final int colUnsafe = UtilFunctions.toInt(v2);
199+
if(!Double.isNaN(v2) && colUnsafe < 0)
200+
throw new DMLRuntimeException(
200201
"Erroneous input while computing the contingency table (value <= zero): " + v2);
202+
// Boolean to int conversion to avoid branch
203+
final int invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol) ? 1 : 0;
204+
// if invalid -> maxOutCol else -> colUnsafe - 1
205+
final int colSafe = maxOutCol*invalid + (colUnsafe - 1)*(1 - invalid);
206+
zeros += invalid;
207+
maxCol = Math.max(colUnsafe, maxCol);
208+
map[i] = colSafe;
209+
}
201210

202-
map[i] = col - 1;
203-
// maintain max seen col
204-
maxCol = Math.max(col, maxCol);
205-
}
211+
if (maxOutCol == -1 && zeros > 0){
212+
maxCol *= -1;
206213
}
207214

208-
return containsNull ? maxCol * -1 : maxCol;
215+
return new Pair<Integer, Integer>(maxCol, zeros);
209216
}
210217

211218
public static boolean compressedTableSeq() {

src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ private void processSimpleCompressInstruction(ExecutionContext ec) {
138138
else if(ec.isMatrixObject(input1.getName()))
139139
processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), _numThreads, root);
140140
else {
141-
throw new NotImplementedException("Not supported other types of input for compression than frame and matrix");
141+
LOG.warn("Compression on Scalar should not happen");
142+
ScalarObject Scalar = ec.getScalarInput(input1);
143+
ec.setScalarOutput(output.getName(),Scalar);
142144
}
143145
}
144146

0 commit comments

Comments
 (0)