Skip to content

Commit b751389

Browse files
committed
[SYSTEMDS-3827] CLA MultiCBind
This commit adds specialized support for n-way cbind (column-binding several matrix blocks in a single operation) in compressed space. Closes #2208
1 parent cdff385 commit b751389

File tree

8 files changed

+263
-50
lines changed

8 files changed

+263
-50
lines changed

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,8 @@
5050
import org.apache.sysds.runtime.compress.colgroup.ColGroupIO;
5151
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
5252
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
53-
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
5453
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
54+
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
5555
import org.apache.sysds.runtime.compress.lib.CLALibCMOps;
5656
import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
5757
import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
@@ -556,8 +556,8 @@ public MatrixBlock binaryOperationsLeft(BinaryOperator op, MatrixValue thatValue
556556

557557
@Override
558558
public MatrixBlock append(MatrixBlock[] that, MatrixBlock ret, boolean cbind) {
559-
if(cbind && that.length == 1)
560-
return CLALibAppend.append(this, that[0], InfrastructureAnalyzer.getLocalParallelism());
559+
if(cbind)
560+
return CLALibCBind.cbind(this, that, InfrastructureAnalyzer.getLocalParallelism());
561561
else {
562562
MatrixBlock left = getUncompressed("append list or r-bind not supported in compressed");
563563
MatrixBlock[] thatUC = new MatrixBlock[that.length];

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -655,7 +655,7 @@ public AMapToData getMapToData() {
655655
public double getSparsity() {
656656
return 1.0;
657657
}
658-
658+
659659
@Override
660660
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
661661
throw new NotImplementedException();
@@ -710,12 +710,10 @@ protected void decompressToSparseBlockTransposedDenseDictionary(SparseBlockMCSR
710710
public AColGroup combineWithSameIndex(int nRow, int nCol, AColGroup right) {
711711
if(!(right instanceof ColGroupConst))
712712
return super.combineWithSameIndex(nRow, nCol, right);
713-
714713
final IColIndex combIndex = _colIndexes.combine(right.getColIndices().shift(nCol));
715714
final IDictionary b = ((ColGroupConst) right).getDictionary();
716715
final IDictionary combined = DictionaryFactory.cBindDictionaries(_dict, b, this.getNumCols(), right.getNumCols());
717716
return create(combIndex, combined);
718-
719717
}
720718

721719
@Override
@@ -737,10 +735,11 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List<AColGroup> right)
737735
for(int i = 0; i < right.size(); i++) {
738736
AColGroup g = right.get(i);
739737

740-
if(!(g instanceof ColGroupConst) || !(g instanceof ColGroupEmpty)) {
738+
if(!(g instanceof ColGroupConst) && !(g instanceof ColGroupEmpty)) {
741739
return super.combineWithSameIndex(nRow, nCol, right);
742740
}
743741
}
742+
744743
IColIndex combinedIndex = _colIndexes;
745744
int i = 0;
746745
for(AColGroup g : right) {
@@ -751,7 +750,7 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List<AColGroup> right)
751750

752751
return create(combinedIndex, combined);
753752
}
754-
753+
755754
@Override
756755
protected boolean allowShallowIdentityRightMult() {
757756
return true;

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java

Lines changed: 4 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -549,25 +549,12 @@ public AColGroupCompressed combineWithSameIndex(int nRow, int nCol, List<AColGro
549549
final IColIndex combinedColIndex = combineColIndexes(nCol, right);
550550
final double[] combinedDefaultTuple = IContainDefaultTuple.combineDefaultTuples(_reference, right);
551551

552-
// return new ColGroupDDC(combinedColIndex, combined, _data, getCachedCounts());
553-
return new ColGroupSDC(combinedColIndex, this.getNumRows(), combined, combinedDefaultTuple, _indexes, _data,
554-
getCachedCounts());
552+
return new ColGroupSDCFOR(combinedColIndex, this.getNumRows(), combined, _indexes, _data, getCachedCounts(),
553+
combinedDefaultTuple);
555554
}
556555

557556
@Override
558557
public AColGroupCompressed combineWithSameIndex(int nRow, int nCol, AColGroup right) {
559-
// if(right instanceof ColGroupSDCZeros){
560-
// ColGroupSDCZeros rightSDC = ((ColGroupSDCZeros) right);
561-
// IDictionary b = rightSDC.getDictionary();
562-
// IDictionary combined = DictionaryFactory.cBindDictionaries(_dict, b, this.getNumCols(), right.getNumCols());
563-
// IColIndex combinedColIndex = _colIndexes.combine(right.getColIndices().shift(nCol));
564-
// double[] combinedDefaultTuple = new double[_reference.length + right.getNumCols()];
565-
// System.arraycopy(_reference, 0, combinedDefaultTuple, 0, _reference.length);
566-
567-
// return new ColGroupSDC(combinedColIndex, this.getNumRows(), combined, combinedDefaultTuple, _indexes, _data,
568-
// getCachedCounts());
569-
// }
570-
// else{
571558
ColGroupSDCFOR rightSDC = ((ColGroupSDCFOR) right);
572559
IDictionary b = rightSDC.getDictionary();
573560
IDictionary combined = DictionaryFactory.cBindDictionaries(_dict, b, this.getNumCols(), right.getNumCols());
@@ -576,9 +563,8 @@ public AColGroupCompressed combineWithSameIndex(int nRow, int nCol, AColGroup ri
576563
System.arraycopy(_reference, 0, combinedDefaultTuple, 0, _reference.length);
577564
System.arraycopy(rightSDC._reference, 0, combinedDefaultTuple, _reference.length, rightSDC._reference.length);
578565

579-
return new ColGroupSDC(combinedColIndex, this.getNumRows(), combined, combinedDefaultTuple, _indexes, _data,
580-
getCachedCounts());
581-
// }
566+
return new ColGroupSDCFOR(combinedColIndex, this.getNumRows(), combined, _indexes, _data, getCachedCounts(),
567+
combinedDefaultTuple);
582568
}
583569

584570
@Override

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java renamed to src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCBind.java

Lines changed: 129 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -21,27 +21,60 @@
2121

2222
import java.util.ArrayList;
2323
import java.util.List;
24+
import java.util.concurrent.ExecutionException;
25+
import java.util.concurrent.ExecutorService;
26+
import java.util.concurrent.Future;
2427

2528
import org.apache.commons.logging.Log;
2629
import org.apache.commons.logging.LogFactory;
2730
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
2831
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
32+
import org.apache.sysds.runtime.compress.DMLCompressionException;
2933
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
3034
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
3135
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
3236
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
3337
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
3438
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
39+
import org.apache.sysds.runtime.util.CommonThreadPool;
3540

36-
public final class CLALibAppend {
41+
public final class CLALibCBind {
3742

38-
private CLALibAppend(){
43+
private CLALibCBind() {
3944
// private constructor.
4045
}
4146

42-
private static final Log LOG = LogFactory.getLog(CLALibAppend.class.getName());
47+
private static final Log LOG = LogFactory.getLog(CLALibCBind.class.getName());
4348

44-
public static MatrixBlock append(MatrixBlock left, MatrixBlock right, int k) {
49+
public static MatrixBlock cbind(MatrixBlock left, MatrixBlock[] right, int k) {
50+
try {
51+
52+
if(right.length == 1) {
53+
return cbind(left, right[0], k);
54+
}
55+
else {
56+
boolean allCompressed = true;
57+
for(int i = 0; i < right.length && allCompressed; i++)
58+
allCompressed = right[i] instanceof CompressedMatrixBlock;
59+
if(allCompressed)
60+
return cbindAllCompressed((CompressedMatrixBlock) left, right, k);
61+
else
62+
return cbindAllNormalCompressed(left, right, k);
63+
}
64+
}
65+
catch(Exception e) {
66+
throw new DMLCompressionException("Failed to Cbind with compressed input", e);
67+
}
68+
}
69+
70+
private static MatrixBlock cbindAllNormalCompressed(MatrixBlock left, MatrixBlock[] right, int k) {
71+
for(int i = 0; i < right.length; i++) {
72+
left = cbind(left, right[i], k);
73+
}
74+
return left;
75+
}
76+
77+
public static MatrixBlock cbind(MatrixBlock left, MatrixBlock right, int k) {
4578

4679
final int m = left.getNumRows();
4780
final int n = left.getNumColumns() + right.getNumColumns();
@@ -66,15 +99,96 @@ else if(right.isEmpty() && left instanceof CompressedMatrixBlock)
6699
final double spar = (left.getNonZeros() + right.getNonZeros()) / ((double) m * n);
67100
final double estSizeUncompressed = MatrixBlock.estimateSizeInMemory(m, n, spar);
68101
final double estSizeCompressed = left.getInMemorySize() + right.getInMemorySize();
102+
// if(isAligned((CompressedMatrixBlock) left, (CompressedMatrixBlock) right))
103+
// return combineCompressed((CompressedMatrixBlock) left, (CompressedMatrixBlock) right);
104+
// else
69105
if(estSizeUncompressed < estSizeCompressed)
70106
return uc(left).append(uc(right), null);
71107
else if(left instanceof CompressedMatrixBlock)
72108
return appendRightUncompressed((CompressedMatrixBlock) left, right, m, n);
73109
else
74110
return appendLeftUncompressed(left, (CompressedMatrixBlock) right, m, n);
75111
}
112+
if(isAligned((CompressedMatrixBlock) left, (CompressedMatrixBlock) right))
113+
return combineCompressed((CompressedMatrixBlock) left, (CompressedMatrixBlock) right);
114+
else
115+
return append((CompressedMatrixBlock) left, (CompressedMatrixBlock) right, m, n);
116+
}
117+
118+
private static MatrixBlock cbindAllCompressed(CompressedMatrixBlock left, MatrixBlock[] right, int k)
119+
throws InterruptedException, ExecutionException {
120+
121+
final int nCol = left.getNumColumns();
122+
for(int i = 0; i < right.length; i++) {
123+
CompressedMatrixBlock rightCM = ((CompressedMatrixBlock) right[i]);
124+
if(nCol != right[i].getNumColumns() || !isAligned(left, rightCM))
125+
return cbindAllNormalCompressed(left, right, k);
126+
}
127+
return cbindAllCompressedAligned(left, right, k);
128+
129+
}
130+
131+
private static boolean isAligned(CompressedMatrixBlock left, CompressedMatrixBlock right) {
132+
final List<AColGroup> gl = left.getColGroups();
133+
for(int j = 0; j < gl.size(); j++) {
134+
final AColGroup glj = gl.get(j);
135+
final int aColumnInGroup = glj.getColIndices().get(0);
136+
final AColGroup grj = right.getColGroupForColumn(aColumnInGroup);
137+
138+
if(!glj.sameIndexStructure(grj) || glj.getNumCols() != grj.getNumCols())
139+
return false;
140+
141+
}
142+
return true;
143+
}
144+
145+
private static CompressedMatrixBlock combineCompressed(CompressedMatrixBlock left, CompressedMatrixBlock right) {
146+
final List<AColGroup> gl = left.getColGroups();
147+
final List<AColGroup> retCG = new ArrayList<>(gl.size());
148+
for(int j = 0; j < gl.size(); j++) {
149+
AColGroup glj = gl.get(j);
150+
int aColumnInGroup = glj.getColIndices().get(0);
151+
AColGroup grj = right.getColGroupForColumn(aColumnInGroup);
152+
// NOTE: candidate for a parallel combine; the column groups are currently combined one at a time in this loop.
153+
retCG.add(glj.combineWithSameIndex(left.getNumRows(), left.getNumColumns(), grj));
154+
}
155+
return new CompressedMatrixBlock(left.getNumRows(), left.getNumColumns() + right.getNumColumns(),
156+
left.getNonZeros() + right.getNonZeros(), false, retCG);
157+
}
158+
159+
private static CompressedMatrixBlock cbindAllCompressedAligned(CompressedMatrixBlock left, MatrixBlock[] right,
160+
final int k) throws InterruptedException, ExecutionException {
161+
162+
final ExecutorService pool = CommonThreadPool.get(k);
163+
try {
164+
final List<AColGroup> gl = left.getColGroups();
165+
final List<Future<AColGroup>> tasks = new ArrayList<>();
166+
final int nCol = left.getNumColumns();
167+
final int nRow = left.getNumRows();
168+
for(int i = 0; i < gl.size(); i++) {
169+
final AColGroup gli = gl.get(i);
170+
tasks.add(pool.submit(() -> {
171+
List<AColGroup> combines = new ArrayList<>();
172+
final int cId = gli.getColIndices().get(0);
173+
for(int j = 0; j < right.length; j++) {
174+
combines.add(((CompressedMatrixBlock) right[j]).getColGroupForColumn(cId));
175+
}
176+
return gli.combineWithSameIndex(nRow, nCol, combines);
177+
}));
178+
}
179+
180+
final List<AColGroup> retCG = new ArrayList<>(gl.size());
181+
for(Future<AColGroup> t : tasks)
182+
retCG.add(t.get());
183+
184+
int totalCol = nCol + right.length * nCol;
185+
186+
return new CompressedMatrixBlock(left.getNumRows(), totalCol, -1, false, retCG);
187+
}
188+
finally {
189+
pool.shutdown();
190+
}
76191

77-
return append((CompressedMatrixBlock) left, (CompressedMatrixBlock) right, m, n);
78192
}
79193

80194
private static MatrixBlock appendLeftUncompressed(MatrixBlock left, CompressedMatrixBlock right, final int m,
@@ -123,17 +237,17 @@ private static MatrixBlock append(CompressedMatrixBlock left, CompressedMatrixBl
123237
ret.setNonZeros(left.getNonZeros() + right.getNonZeros());
124238
ret.setOverlapping(left.isOverlapping() || right.isOverlapping());
125239

126-
final double compressedSize = ret.getInMemorySize();
127-
final double uncompressedSize = MatrixBlock.estimateSizeInMemory(m, n, ret.getSparsity());
240+
// final double compressedSize = ret.getInMemorySize();
241+
// final double uncompressedSize = MatrixBlock.estimateSizeInMemory(m, n, ret.getSparsity());
128242

129-
if(compressedSize < uncompressedSize)
130-
return ret;
131-
else {
132-
final double ratio = uncompressedSize / compressedSize;
133-
String message = String.format("Decompressing c bind matrix because it had to small compression ratio: %2.3f",
134-
ratio);
135-
return ret.getUncompressed(message);
136-
}
243+
// if(compressedSize < uncompressedSize)
244+
return ret;
245+
// else {
246+
// final double ratio = uncompressedSize / compressedSize;
247+
// String message = String.format("Decompressing c bind matrix because it had to small compression ratio: %2.3f",
248+
// ratio);
249+
// return ret.getUncompressed(message);
250+
// }
137251
}
138252

139253
private static MatrixBlock appendRightEmpty(CompressedMatrixBlock left, MatrixBlock right, int m, int n) {

src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixAppendCPInstruction.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
import org.apache.commons.lang3.tuple.Pair;
2323
import org.apache.sysds.runtime.DMLRuntimeException;
2424
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
25-
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
25+
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
2626
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
2727
import org.apache.sysds.runtime.lineage.LineageItem;
2828
import org.apache.sysds.runtime.lineage.LineageItemUtils;
@@ -46,8 +46,9 @@ public void processInstruction(ExecutionContext ec) {
4646
validateInput(matBlock1, matBlock2);
4747

4848
MatrixBlock ret;
49-
if(matBlock1 instanceof CompressedMatrixBlock || matBlock2 instanceof CompressedMatrixBlock)
50-
ret = CLALibAppend.append(matBlock1, matBlock2, InfrastructureAnalyzer.getLocalParallelism());
49+
if(_type == AppendType.CBIND &&
50+
(matBlock1 instanceof CompressedMatrixBlock || matBlock2 instanceof CompressedMatrixBlock))
51+
ret = CLALibCBind.cbind(matBlock1, matBlock2, InfrastructureAnalyzer.getLocalParallelism());
5152
else
5253
ret = matBlock1.append(matBlock2, new MatrixBlock(), _type == AppendType.CBIND);
5354

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
5757
import org.apache.sysds.runtime.compress.DMLCompressionException;
5858
import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp;
59+
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
5960
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
6061
import org.apache.sysds.runtime.compress.lib.CLALibTernaryOp;
6162
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
@@ -3654,10 +3655,19 @@ public final MatrixBlock append(MatrixBlock that, MatrixBlock ret ) {
36543655
return append(that, ret, true); //default cbind
36553656
}
36563657

3657-
public static MatrixBlock append(List<MatrixBlock> that,MatrixBlock ret, boolean cbind, int k ){
3658-
MatrixBlock[] th = new MatrixBlock[that.size() -1];
3659-
for(int i = 0; i < that.size() -1; i++)
3660-
th[i] = that.get(i+1);
3658+
/**
3659+
* Append a list of matrix blocks. The first block in the list is the base, and all remaining blocks are appended to it.
3660+
*
3661+
* @param that That list.
3662+
* @param ret The output block
3663+
* @param cbind If true, the blocks are appended column-wise (cbind); otherwise row-wise (rbind)
3664+
* @param k the parallelization degree
3665+
* @return the appended matrix.
3666+
*/
3667+
public static MatrixBlock append(List<MatrixBlock> that, MatrixBlock ret, boolean cbind, int k) {
3668+
MatrixBlock[] th = new MatrixBlock[that.size() - 1];
3669+
for(int i = 0; i < that.size() - 1; i++)
3670+
th[i] = that.get(i + 1);
36613671
return that.get(0).append(th, ret, cbind);
36623672
}
36633673

@@ -3716,6 +3726,13 @@ private final int computeNNzRow(MatrixBlock[] that, int row) {
37163726
public MatrixBlock append(MatrixBlock[] that, MatrixBlock result, boolean cbind) {
37173727
checkDimensionsForAppend(that, cbind);
37183728

3729+
for(int k = 0; k < that.length; k++)
3730+
if( that[k] instanceof CompressedMatrixBlock){
3731+
if(that.length == 1 && cbind)
3732+
return CLALibCBind.cbind(this, that[0], 1);
3733+
that[k] = CompressedMatrixBlock.getUncompressed(that[k], "Append N");
3734+
}
3735+
37193736
final int m = cbind ? rlen : combinedRows(that);
37203737
final int n = cbind ? combinedCols(that) : clen;
37213738
final long nnz = calculateCombinedNNz(that);

src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
3838
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
3939
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
40+
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
4041
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
4142
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
4243
import org.apache.sysds.test.TestUtils;
@@ -395,4 +396,10 @@ public void manyRowsButNotQuite() {
395396
MatrixBlock m2 = CompressedMatrixBlockFactory.compress(m1).getLeft();
396397
TestUtils.compareMatricesBitAvgDistance(m1, m2, 0, 0, "no");
397398
}
399+
400+
401+
@Test(expected = Exception.class)
402+
public void cbindWithError(){
403+
CLALibCBind.cbind(null, new MatrixBlock[]{null}, 0);
404+
}
398405
}

0 commit comments

Comments
 (0)