Skip to content

Commit 2ce1910

Browse files
committed
[SYSTEMDS-3771] Compressed Identity Dictionary and Selection Multiply
This commit contains the implementation details on LLM refinements for supporting the new Identity dictionaries, that remove the need for many of the matrix multiplications. Furthermore it also contains the implementation details and optimizations for selective Matrix Multiplications of matrices in the left side containing only a single 1 in each row. The implementation there simply decompress the rows associated with the 1, making the overall compressed operation very efficient. The overall implementation further improves the code-coverage of the project by 0.23% Closes #2084
1 parent eea4afc commit 2ce1910

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+3947
-591
lines changed

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,11 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype
12451245
throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock");
12461246
}
12471247

1248+
@Override
1249+
public MatrixBlock transpose(int k) {
1250+
return getUncompressed().transpose(k);
1251+
}
1252+
12481253
@Override
12491254
public String toString() {
12501255
StringBuilder sb = new StringBuilder();

src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.commons.lang3.NotImplementedException;
2828
import org.apache.commons.logging.Log;
2929
import org.apache.commons.logging.LogFactory;
30+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
3031
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
3132
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult;
3233
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
@@ -728,6 +729,44 @@ public AColGroup sortColumnIndexes() {
728729
*/
729730
public abstract AColGroup reduceCols();
730731

732+
/**
733+
* Selection (left matrix multiply)
734+
*
735+
* @param selection A sparse matrix with "max" a single one in each row all other values are zero.
736+
* @param points The coordinates in the selection matrix to extract.
737+
* @param ret The MatrixBlock to decompress the selected rows into
738+
* @param rl The row to start at in the selection matrix
739+
* @param ru the row to end at in the selection matrix (not inclusive)
740+
*/
741+
public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
742+
if(ret.isInSparseFormat())
743+
sparseSelection(selection, points, ret, rl, ru);
744+
else
745+
denseSelection(selection, points, ret, rl, ru);
746+
}
747+
748+
/**
749+
* Sparse selection (left matrix multiply)
750+
*
751+
* @param selection A sparse matrix with "max" a single one in each row all other values are zero.
752+
* @param points The coordinates in the selection matrix to extract.
753+
* @param ret The Sparse MatrixBlock to decompress the selected rows into
754+
* @param rl The row to start at in the selection matrix
755+
* @param ru the row to end at in the selection matrix (not inclusive)
756+
*/
757+
protected abstract void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru);
758+
759+
/**
760+
* Dense selection (left matrix multiply)
761+
*
762+
* @param selection A sparse matrix with "max" a single one in each row all other values are zero.
763+
* @param points The coordinates in the selection matrix to extract.
764+
* @param ret The Dense MatrixBlock to decompress the selected rows into
765+
* @param rl The row to start at in the selection matrix
766+
* @param ru the row to end at in the selection matrix (not inclusive)
767+
*/
768+
protected abstract void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru);
769+
731770
@Override
732771
public String toString() {
733772
StringBuilder sb = new StringBuilder();

src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ else if(lhs instanceof ColGroupUncompressed)
8585
* @return A aggregate dictionary
8686
*/
8787
public final IDictionary preAggregateThatIndexStructure(APreAgg that) {
88-
final long outputLength = (long)that._colIndexes.size() * this.getNumValues();
88+
final long outputLength = (long) that._colIndexes.size() * this.getNumValues();
8989
if(outputLength > Integer.MAX_VALUE)
9090
throw new NotImplementedException("Not supported pre aggregate of above integer length");
9191
if(outputLength <= 0) // if the pre aggregate output is empty or nothing, return null
9292
return null;
93-
93+
9494
// create empty Dictionary that we slowly fill, hence the dictionary is empty and no check
95-
final Dictionary ret = Dictionary.createNoCheck(new double[(int)outputLength]);
95+
final Dictionary ret = Dictionary.createNoCheck(new double[(int) outputLength]);
9696

9797
if(that instanceof ColGroupDDC)
9898
preAggregateThatDDCStructure((ColGroupDDC) that, ret);
@@ -119,7 +119,7 @@ else if(that instanceof ColGroupRLE)
119119
*/
120120
public final void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) {
121121
if(m.isInSparseFormat())
122-
preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru);
122+
preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru, 0, m.getNumColumns());
123123
else
124124
preAggregateDense(m, preAgg, rl, ru, 0, m.getNumColumns());
125125
}
@@ -136,7 +136,7 @@ public final void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) {
136136
*/
137137
public abstract void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu);
138138

139-
public abstract void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru);
139+
public abstract void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu);
140140

141141
protected abstract void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret);
142142

@@ -160,11 +160,13 @@ private void tsmmAPreAgg(APreAgg lg, MatrixBlock result) {
160160
final boolean left = shouldPreAggregateLeft(lg);
161161
if(!loggedWarningForDirect && shouldDirectMultiply(lg, leftIdx.size(), rightIdx.size(), left)) {
162162
loggedWarningForDirect = true;
163-
LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% " + this.getClass().getSimpleName() );
163+
LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% "
164+
+ this.getClass().getSimpleName());
164165
}
165166

166167
if(left) {
167168
final IDictionary lpa = this.preAggregateThatIndexStructure(lg);
169+
168170
if(lpa != null)
169171
DictLibMatrixMult.TSMMToUpperTriangle(lpa, _dict, leftIdx, rightIdx, result);
170172
}
@@ -222,7 +224,7 @@ else if(shouldPreAggregateLeft(lhs)) {// left preAgg
222224
DictLibMatrixMult.MMDicts(lDict, lhsPA, leftIdx, rightIdx, result);
223225
}
224226
else {// right preAgg
225-
final IDictionary rhsPA = preAggregateThatIndexStructure(lhs);
227+
final IDictionary rhsPA = this.preAggregateThatIndexStructure(lhs);
226228
if(rhsPA != null)
227229
DictLibMatrixMult.MMDicts(rhsPA, rDict, leftIdx, rightIdx, result);
228230
}
@@ -311,17 +313,20 @@ public void mmWithDictionary(MatrixBlock preAgg, MatrixBlock tmpRes, MatrixBlock
311313
// Shallow copy the preAgg to allow sparse PreAgg multiplication but do not remove the original dense allocation
312314
// since the dense allocation is reused.
313315
final MatrixBlock preAggCopy = new MatrixBlock();
314-
preAggCopy.copy(preAgg);
316+
preAggCopy.copyShallow(preAgg);
315317
final MatrixBlock tmpResCopy = new MatrixBlock();
316-
tmpResCopy.copy(tmpRes);
318+
tmpResCopy.copyShallow(tmpRes);
317319
// Get dictionary matrixBlock
318320
final MatrixBlock dict = getDictionary().getMBDict(_colIndexes.size()).getMatrixBlock();
319321
if(dict != null) {
320322
// Multiply
321-
LibMatrixMult.matrixMult(preAggCopy, dict, tmpResCopy, k);
322-
ColGroupUtils.addMatrixToResult(tmpResCopy, ret, _colIndexes, rl, ru);
323+
LibMatrixMult.matrixMult(preAggCopy, dict, tmpRes, k);
324+
ColGroupUtils.addMatrixToResult(tmpRes, ret, _colIndexes, rl, ru);
323325
}
324326
}
325327

326328
protected abstract int numRowsToMultiply();
329+
330+
public abstract void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl,
331+
int cu);
327332
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import org.apache.commons.lang3.NotImplementedException;
2626
import org.apache.sysds.runtime.compress.DMLCompressionException;
27+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
2728
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
2829
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
2930
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
@@ -647,4 +648,14 @@ public AMapToData getMapToData() {
647648
return MapToFactory.create(0, 0);
648649
}
649650

651+
@Override
652+
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
653+
throw new NotImplementedException();
654+
}
655+
656+
@Override
657+
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
658+
throw new NotImplementedException();
659+
}
660+
650661
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@
2828
import org.apache.sysds.runtime.DMLRuntimeException;
2929
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
3030
import org.apache.sysds.runtime.compress.DMLCompressionException;
31+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
3132
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
3233
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
3334
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
35+
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
3436
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
3537
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
3638
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
39+
import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex;
3740
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
3841
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToByte;
3942
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar;
@@ -398,7 +401,10 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in
398401
}
399402

400403
@Override
401-
public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) {
404+
public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) {
405+
if(cl != 0 || cu != _data.size()) {
406+
throw new NotImplementedException();
407+
}
402408
_data.preAggregateSparse(sb, preAgg, rl, ru);
403409
}
404410

@@ -628,6 +634,90 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
628634
return ColGroupDDC.create(newColIndex, _dict.reorder(reordering), _data, getCachedCounts());
629635
}
630636

637+
@Override
638+
public void sparseSelection(MatrixBlock selection,P[] points, MatrixBlock ret, int rl, int ru) {
639+
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
640+
final SparseBlock sb = selection.getSparseBlock();
641+
final SparseBlock retB = ret.getSparseBlock();
642+
for(int r = rl; r < ru; r++) {
643+
if(sb.isEmpty(r))
644+
continue;
645+
final int sPos = sb.pos(r);
646+
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
647+
decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
648+
}
649+
}
650+
651+
652+
@Override
653+
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
654+
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
655+
final SparseBlock sb = selection.getSparseBlock();
656+
final DenseBlock retB = ret.getDenseBlock();
657+
for(int r = rl; r < ru; r++) {
658+
if(sb.isEmpty(r))
659+
continue;
660+
final int sPos = sb.pos(r);
661+
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
662+
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
663+
}
664+
}
665+
666+
@Override
667+
public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) {
668+
DenseBlock db = that.getDenseBlock();
669+
DenseBlock retDB = ret.getDenseBlock();
670+
if(rl == ru - 1)
671+
leftMMIdentityPreAggregateDenseSingleRow(db.values(rl), db.pos(rl), retDB.values(rl), retDB.pos(rl), cl, cu);
672+
else
673+
throw new NotImplementedException();
674+
}
675+
676+
677+
private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos, double[] values2, int pos2, int cl,
678+
int cu) {
679+
IdentityDictionary a = (IdentityDictionary) _dict;
680+
if(_colIndexes instanceof RangeIndex)
681+
leftMMIdentityPreAggregateDenseSingleRowRangeIndex(values, pos, values2, pos2, cl, cu);
682+
else {
683+
684+
pos += cl; // left side matrix position offset.
685+
if(a.withEmpty()) {
686+
final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1;
687+
for(int rc = cl; rc < cu; rc++, pos++) {
688+
final int idx = _data.getIndex(rc);
689+
if(idx != nVal)
690+
values2[_colIndexes.get(idx)] += values[pos];
691+
}
692+
}
693+
else {
694+
for(int rc = cl; rc < cu; rc++, pos++)
695+
values2[_colIndexes.get(_data.getIndex(rc))] += values[pos];
696+
}
697+
}
698+
}
699+
700+
701+
private void leftMMIdentityPreAggregateDenseSingleRowRangeIndex(double[] values, int pos, double[] values2, int pos2,
702+
int cl, int cu) {
703+
IdentityDictionary a = (IdentityDictionary) _dict;
704+
705+
final int firstCol = _colIndexes.get(0);
706+
pos += cl; // left side matrix position offset.
707+
if(a.withEmpty()) {
708+
final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1;
709+
for(int rc = cl; rc < cu; rc++, pos++) {
710+
final int idx = _data.getIndex(rc);
711+
if(idx != nVal)
712+
values2[firstCol + idx] += values[pos];
713+
}
714+
}
715+
else {
716+
for(int rc = cl; rc < cu; rc++, pos++)
717+
values2[firstCol + _data.getIndex(rc)] += values[pos];
718+
}
719+
}
720+
631721
@Override
632722
public String toString() {
633723
StringBuilder sb = new StringBuilder();

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import org.apache.commons.lang3.NotImplementedException;
2828
import org.apache.sysds.runtime.DMLRuntimeException;
29+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
2930
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
3031
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
3132
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
@@ -40,6 +41,8 @@
4041
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
4142
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
4243
import org.apache.sysds.runtime.compress.utils.Util;
44+
import org.apache.sysds.runtime.data.DenseBlock;
45+
import org.apache.sysds.runtime.data.SparseBlock;
4346
import org.apache.sysds.runtime.functionobjects.Builtin;
4447
import org.apache.sysds.runtime.functionobjects.Divide;
4548
import org.apache.sysds.runtime.functionobjects.Minus;
@@ -252,7 +255,7 @@ public AColGroup replace(double pattern, double replace) {
252255
if(patternInReference) {
253256
double[] nRef = new double[_reference.length];
254257
for(int i = 0; i < _reference.length; i++)
255-
if(Util.eq(pattern ,_reference[i]))
258+
if(Util.eq(pattern, _reference[i]))
256259
nRef[i] = replace;
257260
else
258261
nRef[i] = _reference[i];
@@ -489,6 +492,34 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
489492
throw new NotImplementedException();
490493
}
491494

495+
@Override
496+
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
497+
final SparseBlock sb = selection.getSparseBlock();
498+
final SparseBlock retB = ret.getSparseBlock();
499+
for(int r = rl; r < ru; r++) {
500+
if(sb.isEmpty(r))
501+
continue;
502+
503+
final int sPos = sb.pos(r);
504+
final int rowCompressed = sb.indexes(r)[sPos];
505+
decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
506+
}
507+
}
508+
509+
@Override
510+
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
511+
final SparseBlock sb = selection.getSparseBlock();
512+
final DenseBlock retB = ret.getDenseBlock();
513+
for(int r = rl; r < ru; r++) {
514+
if(sb.isEmpty(r))
515+
continue;
516+
517+
final int sPos = sb.pos(r);
518+
final int rowCompressed = sb.indexes(r)[sPos];
519+
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
520+
}
521+
}
522+
492523
@Override
493524
public String toString() {
494525
StringBuilder sb = new StringBuilder();

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
import java.io.IOException;
2424
import java.util.Arrays;
2525

26+
import org.apache.commons.lang3.NotImplementedException;
2627
import org.apache.sysds.runtime.DMLRuntimeException;
2728
import org.apache.sysds.runtime.compress.DMLCompressionException;
29+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
2830
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
2931
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
3032
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
@@ -53,7 +55,7 @@
5355
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
5456

5557
public class ColGroupEmpty extends AColGroupCompressed
56-
implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup ,IMapToDataGroup{
58+
implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup, IMapToDataGroup {
5759
private static final long serialVersionUID = -2307677253622099958L;
5860

5961
/**
@@ -403,9 +405,18 @@ public AMapToData getMapToData() {
403405
return MapToFactory.create(0, 0);
404406
}
405407

406-
@Override
407-
public AColGroup reduceCols(){
408+
@Override
409+
public AColGroup reduceCols() {
408410
return null;
409411
}
410412

413+
@Override
414+
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
415+
throw new NotImplementedException();
416+
}
417+
418+
@Override
419+
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
420+
throw new NotImplementedException();
421+
}
411422
}

0 commit comments

Comments
 (0)