Skip to content

Commit 89623cc

Browse files
committed
Add Columngroups Update
1 parent 1fc1499 commit 89623cc

30 files changed

+2274
-357
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java

Lines changed: 182 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,27 @@
2323
import java.io.IOException;
2424
import java.io.Serializable;
2525
import java.util.Collection;
26+
import java.util.List;
27+
import java.util.concurrent.ExecutorService;
2628

2729
import org.apache.commons.lang3.NotImplementedException;
2830
import org.apache.commons.logging.Log;
2931
import org.apache.commons.logging.LogFactory;
3032
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
33+
import org.apache.sysds.runtime.compress.CompressionSettings;
34+
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
3135
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
3236
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
3337
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult;
3438
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
3539
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
40+
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
3641
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
3742
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
3843
import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups;
3944
import org.apache.sysds.runtime.data.DenseBlock;
4045
import org.apache.sysds.runtime.data.SparseBlock;
46+
import org.apache.sysds.runtime.data.SparseBlockMCSR;
4147
import org.apache.sysds.runtime.functionobjects.Plus;
4248
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
4349
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -165,14 +171,32 @@ public final void decompressToSparseBlock(SparseBlock sb, int rl, int ru) {
165171
/**
166172
* Decompress a range of rows into a dense block
167173
*
168-
* @param db Sparse Target block
174+
* @param db Dense target block
169175
* @param rl Row to start at
170176
* @param ru Row to end at
171177
*/
172178
public final void decompressToDenseBlock(DenseBlock db, int rl, int ru) {
173179
// Delegate to the five-argument overload, passing 0 for both trailing arguments
// (presumably the row and column offsets into the target - confirm against that overload).
decompressToDenseBlock(db, rl, ru, 0, 0);
174180
}
175181

182+
/**
183+
* Decompress a range of rows into a dense transposed block.
184+
*
185+
* @param db Dense target block
186+
* @param rl Row in this column group to start at.
187+
* @param ru Row in this column group to end at.
188+
*/
189+
public abstract void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru);
190+
191+
/**
192+
* Decompress the column group to the sparse transposed block. Note that the column groups would only need to
193+
* decompress into specific sub rows of the Sparse block
194+
*
195+
* @param sb Sparse target block
196+
* @param nColOut The number of columns in the sb.
197+
*/
198+
public abstract void decompressToSparseBlockTransposed(SparseBlockMCSR sb, int nColOut);
199+
176200
/**
177201
* Serializes column group to data output.
178202
*
@@ -321,7 +345,7 @@ public double get(int r, int c) {
321345
*
322346
* @param db Target DenseBlock
323347
* @param rl Row to start decompression from
324-
* @param ru Row to end decompression at
348+
* @param ru Row to end decompression at (not inclusive)
325349
* @param offR Row offset into the target to decompress
326350
* @param offC Column offset into the target to decompress
327351
*/
@@ -335,7 +359,7 @@ public double get(int r, int c) {
335359
*
336360
* @param sb Target SparseBlock
337361
* @param rl Row to start decompression from
338-
* @param ru Row to end decompression at
362+
* @param ru Row to end decompression at (not inclusive)
339363
* @param offR Row offset into the target to decompress
340364
* @param offC Column offset into the target to decompress
341365
*/
@@ -350,7 +374,7 @@ public double get(int r, int c) {
350374
* @return The new Column Group or null that is the result of the matrix multiplication.
351375
*/
352376
public final AColGroup rightMultByMatrix(MatrixBlock right) {
353-
return rightMultByMatrix(right, null);
377+
return rightMultByMatrix(right, null, 1);
354378
}
355379

356380
/**
@@ -361,9 +385,25 @@ public final AColGroup rightMultByMatrix(MatrixBlock right) {
361385
* @param right The MatrixBlock on the right of this matrix multiplication
362386
* @param allCols A pre-materialized list of all col indexes, that can be shared across all column groups if use
363387
* full, can be set to null.
388+
* @param k The parallelization degree allowed internally in this operation.
364389
* @return The new Column Group or null that is the result of the matrix multiplication.
365390
*/
366-
public abstract AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols);
391+
public abstract AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols, int k);
392+
393+
/**
394+
* Right side Matrix multiplication, iterating though this column group and adding to the ret
395+
*
396+
* @param right Right side matrix to multiply with.
397+
* @param ret The return matrix to add results to
398+
* @param rl The row of this column group to multiply from
399+
* @param ru The row of this column group to multiply to (not inclusive)
400+
* @param crl The right hand side column lower
401+
* @param cru The right hand side column upper
402+
* @param nRows The number of rows in this column group
403+
*/
404+
public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru){
405+
throw new NotImplementedException("not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName());
406+
}
367407

368408
/**
369409
* Do a transposed self matrix multiplication on the left side t(x) %*% x. but only with this column group.
@@ -766,7 +806,7 @@ public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlo
766806
else
767807
denseSelection(selection, points, ret, rl, ru);
768808
}
769-
809+
770810
/**
771811
* Get an approximate sparsity of this column group
772812
*
@@ -796,6 +836,142 @@ public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlo
796836
*/
797837
protected abstract void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru);
798838

839+
/**
840+
* Method to determine if the columnGroup have the same index structure as another. Note that the column indexes and
841+
* dictionaries are allowed to be different.
842+
*
843+
* @param that the other column group
844+
* @return if the index is the same.
845+
*/
846+
public boolean sameIndexStructure(AColGroup that) {
847+
// Base implementation: never reports a shared index structure, regardless of the argument.
return false;
848+
}
849+
850+
/**
851+
* C bind the list of column groups with this column group. the list of elements provided in the index of each list
852+
* is guaranteed to have the same index structures
853+
*
854+
* @param nRow The number of rows contained in all right and this column group.
855+
* @param nCol The number of columns to shift the right hand side column groups over when combining, this should
856+
* only effect the column indexes
857+
* @param right The right hand side column groups to combine. NOTE only the index offset of the second nested list
858+
* should be used. The reason for providing this nested list is to avoid redundant allocations in
859+
* calling methods.
860+
* @return A combined compressed column group of the same type as this!.
861+
*/
862+
public AColGroup combineWithSameIndex(int nRow, int nCol, List<AColGroup> right) {
863+
// default decompress... nasty !
864+
865+
IColIndex combinedColIndex = combineColIndexes(nCol, right);
866+
867+
MatrixBlock decompressTarget = new MatrixBlock(nRow, combinedColIndex.size(), false);
868+
869+
decompressTarget.allocateDenseBlock();
870+
DenseBlock db = decompressTarget.getDenseBlock();
871+
final int nColInThisGroup = _colIndexes.size();
872+
this.copyAndSet(ColIndexFactory.create(nColInThisGroup)).decompressToDenseBlock(db, 0, nRow);
873+
874+
for(int i = 0; i < right.size(); i++) {
875+
right.get(i).copyAndSet(ColIndexFactory.create(i * nColInThisGroup, i * nColInThisGroup + nColInThisGroup))
876+
.decompressToDenseBlock(db, 0, nRow);
877+
}
878+
879+
decompressTarget.setNonZeros(nRow * combinedColIndex.size());
880+
881+
CompressedSizeInfoColGroup ci = new CompressedSizeInfoColGroup(ColIndexFactory.create(combinedColIndex.size()),
882+
nRow, nRow, CompressionType.DDC);
883+
CompressedSizeInfo csi = new CompressedSizeInfo(ci);
884+
885+
CompressionSettings cs = new CompressionSettingsBuilder().create();
886+
return ColGroupFactory.compressColGroups(decompressTarget, csi, cs).get(0).copyAndSet(combinedColIndex);
887+
}
888+
889+
/**
890+
* C bind the given column group to this.
891+
*
892+
* @param nRow The number of rows contained in the right and this column group.
893+
* @param nCol The number of columns in this.
894+
* @param right The column group to c-bind.
895+
* @return a new combined column groups.
896+
*/
897+
public AColGroup combineWithSameIndex(int nRow, int nCol, AColGroup right) {
898+
899+
IColIndex combinedColIndex = _colIndexes.combine(right._colIndexes.shift(nCol));
900+
901+
MatrixBlock decompressTarget = new MatrixBlock(nRow, combinedColIndex.size(), false);
902+
903+
decompressTarget.allocateDenseBlock();
904+
DenseBlock db = decompressTarget.getDenseBlock();
905+
final int nColInThisGroup = _colIndexes.size();
906+
this.copyAndSet(ColIndexFactory.create(nColInThisGroup)).decompressToDenseBlock(db, 0, nRow);
907+
908+
right.copyAndSet(ColIndexFactory.create(nColInThisGroup, nColInThisGroup + nColInThisGroup))
909+
.decompressToDenseBlock(db, 0, nRow);
910+
911+
decompressTarget.setNonZeros(nRow * combinedColIndex.size());
912+
913+
CompressedSizeInfoColGroup ci = new CompressedSizeInfoColGroup(ColIndexFactory.create(combinedColIndex.size()),
914+
nRow, nRow, CompressionType.DDC);
915+
CompressedSizeInfo csi = new CompressedSizeInfo(ci);
916+
917+
CompressionSettings cs = new CompressionSettingsBuilder().create();
918+
return ColGroupFactory.compressColGroups(decompressTarget, csi, cs).get(0).copyAndSet(combinedColIndex);
919+
// throw new NotImplementedException("Combine of : " + this.getClass().getSimpleName() + " not implemented");
920+
}
921+
922+
protected IColIndex combineColIndexes(final int nCol, List<AColGroup> right) {
923+
IColIndex combinedColIndex = _colIndexes;
924+
for(int i = 0; i < right.size(); i++)
925+
combinedColIndex = combinedColIndex.combine(right.get(i).getColIndices().shift(nCol * i + nCol));
926+
return combinedColIndex;
927+
}
928+
929+
/**
930+
* This method returns a list of column groups that are naive splits of this column group as if it is reshaped.
931+
*
932+
* This means the column groups rows are split into x number of other column groups where x is the multiplier.
933+
*
934+
* The indexes are assigned round robin to each of the output groups, meaning the first index is assigned.
935+
*
936+
* If, for instance, the column group at column index 4 is split by a multiplier of 2 and there were 5 columns in total
937+
* originally, the output becomes 2 column groups: one at column index 4 and one at 9.
938+
*
939+
* If possible the split column groups should reuse pointers back to the original dictionaries!
940+
*
941+
* @param multiplier The number of column groups to split into
942+
* @param nRow The number of rows in this column group in case the underlying column group does not know
943+
* @param nColOrg The number of overall columns in the host CompressedMatrixBlock.
944+
* @return a list of split column groups
945+
*/
946+
public abstract AColGroup[] splitReshape(final int multiplier, final int nRow, final int nColOrg);
947+
948+
/**
949+
* This method returns a list of column groups that are naive splits of this column group as if it is reshaped.
950+
*
951+
* This means the column groups rows are split into x number of other column groups where x is the multiplier.
952+
*
953+
* The indexes are assigned round robbin to each of the output groups, meaning the first index is assigned.
954+
*
955+
* If for instance the 4. column group is split by a 2 multiplier and there was 5 columns in total originally. The
956+
* output becomes 2 column groups at column index 4 and one at 9.
957+
*
958+
* If possible the split column groups should reuse pointers back to the original dictionaries!
959+
*
960+
* This specific variation is pushing down the parallelization given via the executor service provided. If not
961+
* overwritten the default is to call the normal split reshape
962+
*
963+
* @param multiplier The number of column groups to split into
964+
* @param nRow The number of rows in this column group in case the underlying column group does not know
965+
* @param nColOrg The number of overall columns in the host CompressedMatrixBlock
966+
* @param pool The executor service to submit parallel tasks to
967+
* @throws Exception In case there is an error we throw the exception out instead of handling it
968+
* @return a list of split column groups
969+
*/
970+
public AColGroup[] splitReshapePushDown(final int multiplier, final int nRow, final int nColOrg,
971+
final ExecutorService pool) throws Exception {
972+
return splitReshape(multiplier, nRow, nColOrg);
973+
}
974+
799975
@Override
800976
public String toString() {
801977
StringBuilder sb = new StringBuilder();

src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ protected AColGroupCompressed(IColIndex colIndices) {
8686

8787
protected abstract double[] preAggBuiltinRows(Builtin builtin);
8888

89+
@Override
90+
public boolean sameIndexStructure(AColGroup that) {
91+
if(that instanceof AColGroupCompressed)
92+
return sameIndexStructure((AColGroupCompressed) that);
93+
else
94+
return false;
95+
}
96+
8997
public abstract boolean sameIndexStructure(AColGroupCompressed that);
9098

9199
public double[] preAggRows(ValueFunction fn) {
@@ -215,7 +223,8 @@ protected static void tsmm(double[] result, int numColumns, int[] counts, IDicti
215223

216224
}
217225

218-
protected static void tsmmDense(double[] result, int numColumns, double[] values, int[] counts, IColIndex colIndexes) {
226+
protected static void tsmmDense(double[] result, int numColumns, double[] values, int[] counts,
227+
IColIndex colIndexes) {
219228
final int nCol = colIndexes.size();
220229
final int nRow = counts.length;
221230
for(int k = 0; k < nRow; k++) {
@@ -231,7 +240,8 @@ protected static void tsmmDense(double[] result, int numColumns, double[] values
231240
}
232241
}
233242

234-
protected static void tsmmSparse(double[] result, int numColumns, SparseBlock sb, int[] counts, IColIndex colIndexes) {
243+
protected static void tsmmSparse(double[] result, int numColumns, SparseBlock sb, int[] counts,
244+
IColIndex colIndexes) {
235245
for(int row = 0; row < counts.length; row++) {
236246
if(sb.isEmpty(row))
237247
continue;

src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupOffset.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,9 @@ public long getExactSizeOnDisk() {
142142
public boolean containZerosTuples() {
143143
return _zeros;
144144
}
145+
146+
// Always answers true for this column group type.
// NOTE(review): the contract of this hook is declared on a super class outside this view - confirm there
// what a "shallow identity right mult" permits.
@Override
147+
protected boolean allowShallowIdentityRightMult() {
148+
return true;
149+
}
145150
}

0 commit comments

Comments
 (0)