Skip to content

Commit c3d88a4

Browse files
committed
[SYSTEMDS-???] de-compressing Transpose
1 parent f7af63f commit c3d88a4

File tree

2 files changed

+174
-31
lines changed

2 files changed

+174
-31
lines changed

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
5858
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
5959
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
60+
import org.apache.sysds.runtime.compress.lib.CLALibReorg;
6061
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
6162
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
6263
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -306,7 +307,7 @@ public void putInto(MatrixBlock target, int rowOffset, int colOffset, boolean sp
306307
* @return The cached decompressed matrix, if it does not exist return null
307308
*/
308309
public MatrixBlock getCachedDecompressed() {
309-
if( allowCachingUncompressed && decompressedVersion != null) {
310+
if(allowCachingUncompressed && decompressedVersion != null) {
310311
final MatrixBlock mb = decompressedVersion.get();
311312
if(mb != null) {
312313
DMLCompressionStatistics.addDecompressCacheCount();
@@ -400,8 +401,8 @@ public long estimateCompressedSizeInMemory() {
400401
long total = baseSizeInMemory();
401402
// take into consideration duplicate dictionaries
402403
Set<IDictionary> dicts = new HashSet<>();
403-
for(AColGroup grp : _colGroups){
404-
if(grp instanceof ADictBasedColGroup){
404+
for(AColGroup grp : _colGroups) {
405+
if(grp instanceof ADictBasedColGroup) {
405406
IDictionary dg = ((ADictBasedColGroup) grp).getDictionary();
406407
if(dicts.contains(dg))
407408
total -= dg.getInMemorySize();
@@ -575,8 +576,7 @@ public void append(MatrixValue v2, ArrayList<IndexedMatrixValue> outlist, int bl
575576
}
576577

577578
@Override
578-
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype,
579-
int k) {
579+
public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype, int k) {
580580

581581
checkMMChain(ctype, v, w);
582582
// multi-threaded MMChain of single uncompressed ColGroup
@@ -653,21 +653,7 @@ else if(isOverlapping()) {
653653

654654
@Override
655655
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) {
656-
if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) {
657-
MatrixBlock tmp = decompress(op.getNumThreads());
658-
long nz = tmp.setNonZeros(tmp.getNonZeros());
659-
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
660-
tmp.setNonZeros(nz);
661-
return tmp;
662-
}
663-
else {
664-
// Allow transpose to be compressed output. In general we need to have a transposed flag on
665-
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
666-
String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName();
667-
MatrixBlock tmp = getUncompressed(message, op.getNumThreads());
668-
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
669-
}
670-
656+
return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length);
671657
}
672658

673659
public boolean isOverlapping() {
@@ -709,10 +695,10 @@ public boolean containsValue(double pattern) {
709695
return false;
710696
}
711697
}
712-
698+
713699
@Override
714700
public boolean containsValue(double pattern, int k) {
715-
//TODO parallel contains value
701+
// TODO parallel contains value
716702
return containsValue(pattern);
717703
}
718704

@@ -774,8 +760,8 @@ public boolean isEmptyBlock(boolean safe) {
774760
return false;
775761
else if(_colGroups == null || nonZeros == 0)
776762
return true;
777-
else{
778-
if(nonZeros == -1){
763+
else {
764+
if(nonZeros == -1) {
779765
// try to use column groups
780766
for(AColGroup g : _colGroups)
781767
if(!g.isEmpty())
@@ -1176,8 +1162,7 @@ public void appendRow(int r, SparseRow row, boolean deep) {
11761162
}
11771163

11781164
@Override
1179-
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset,
1180-
boolean deep) {
1165+
public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, int rowoffset, int coloffset, boolean deep) {
11811166
throw new DMLCompressionException("Can't append row to compressed Matrix");
11821167
}
11831168

@@ -1237,7 +1222,7 @@ public void sparseToDense(int k) {
12371222
}
12381223

12391224
@Override
1240-
public void denseToSparse(boolean allowCSR, int k){
1225+
public void denseToSparse(boolean allowCSR, int k) {
12411226
// do nothing
12421227
}
12431228

@@ -1326,13 +1311,13 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
13261311
throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock");
13271312
}
13281313

1329-
@Override
1314+
@Override
13301315
public MatrixBlock transpose(int k) {
1331-
return getUncompressed().transpose(k);
1316+
return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0);
13321317
}
13331318

1334-
@Override
1335-
public MatrixBlock reshape(int rows,int cols, boolean byRow){
1319+
@Override
1320+
public MatrixBlock reshape(int rows, int cols, boolean byRow) {
13361321
return CLALibReshape.reshape(this, rows, cols, byRow);
13371322
}
13381323

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.compress.lib;
21+
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.concurrent.ExecutorService;
25+
import java.util.concurrent.Future;
26+
27+
import org.apache.commons.logging.Log;
28+
import org.apache.commons.logging.LogFactory;
29+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
30+
import org.apache.sysds.runtime.compress.DMLCompressionException;
31+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
32+
import org.apache.sysds.runtime.data.DenseBlock;
33+
import org.apache.sysds.runtime.data.SparseBlock;
34+
import org.apache.sysds.runtime.data.SparseBlockMCSR;
35+
import org.apache.sysds.runtime.functionobjects.SwapIndex;
36+
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
37+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
38+
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
39+
import org.apache.sysds.runtime.util.CommonThreadPool;
40+
41+
public class CLALibReorg {
42+
43+
protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName());
44+
45+
public static boolean warned = false;
46+
47+
public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow,
48+
int startColumn, int length) {
49+
// SwapIndex is transpose
50+
if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) {
51+
MatrixBlock tmp = cmb.decompress(op.getNumThreads());
52+
long nz = tmp.setNonZeros(tmp.getNonZeros());
53+
if(tmp.isInSparseFormat())
54+
return LibMatrixReorg.transpose(tmp); // edge case...
55+
else
56+
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
57+
tmp.setNonZeros(nz);
58+
return tmp;
59+
}
60+
else if(op.fn instanceof SwapIndex) {
61+
if(cmb.getCachedDecompressed() != null)
62+
return cmb.getCachedDecompressed().reorgOperations(op, ret, startRow, startColumn, length);
63+
64+
return transpose(cmb, ret, op.getNumThreads());
65+
}
66+
else {
67+
// Allow transpose to be compressed output. In general we need to have a transposed flag on
68+
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
69+
String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null;
70+
MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads());
71+
warned = true;
72+
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
73+
}
74+
}
75+
76+
private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
77+
78+
final long nnz = cmb.getNonZeros();
79+
final int nRow = cmb.getNumRows();
80+
final int nCol = cmb.getNumColumns();
81+
final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nRow, nCol, nnz);
82+
if(sparseOut)
83+
return transposeSparse(cmb, ret, k, nRow, nCol, nnz);
84+
else
85+
return transposeDense(cmb, ret, k, nRow, nCol, nnz);
86+
}
87+
88+
private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
89+
long nnz) {
90+
if(ret == null)
91+
ret = new MatrixBlock(nCol, nRow, true, nnz);
92+
else
93+
ret.reset(nCol, nRow, true, nnz);
94+
95+
ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR);
96+
97+
final int nColOut = ret.getNumColumns();
98+
99+
if(k > 1)
100+
decompressToTransposedSparseParallel((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut, k);
101+
else
102+
decompressToTransposedSparseSingleThread((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut);
103+
104+
return ret;
105+
}
106+
107+
private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
108+
long nnz) {
109+
if(ret == null)
110+
ret = new MatrixBlock(nCol, nRow, false, nnz);
111+
else
112+
ret.reset(nCol, nRow, false, nnz);
113+
114+
// TODO: parallelize
115+
ret.allocateDenseBlock();
116+
117+
decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow);
118+
return ret;
119+
}
120+
121+
private static void decompressToTransposedDense(DenseBlock ret, List<AColGroup> groups, int rlen, int rl, int ru) {
122+
for(int i = 0; i < groups.size(); i++) {
123+
AColGroup g = groups.get(i);
124+
g.decompressToDenseBlockTransposed(ret, rl, ru);
125+
}
126+
}
127+
128+
private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List<AColGroup> groups,
129+
int nColOut) {
130+
for(int i = 0; i < groups.size(); i++) {
131+
AColGroup g = groups.get(i);
132+
g.decompressToSparseBlockTransposed(ret, nColOut);
133+
}
134+
}
135+
136+
private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut,
137+
int k) {
138+
final ExecutorService pool = CommonThreadPool.get(k);
139+
try {
140+
final List<Future<?>> tasks = new ArrayList<>(groups.size());
141+
142+
for(int i = 0; i < groups.size(); i++) {
143+
final AColGroup g = groups.get(i);
144+
tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut)));
145+
}
146+
147+
for(Future<?> f : tasks)
148+
f.get();
149+
150+
}
151+
catch(Exception e) {
152+
throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
153+
}
154+
finally {
155+
pool.shutdown();
156+
}
157+
}
158+
}

0 commit comments

Comments
 (0)