Skip to content

Commit 0ae6f28

Browse files
committed
CLALib Combine Columngroups With Morphing
1 parent 29b4d92 commit 0ae6f28

33 files changed

+2264
-567
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.commons.logging.Log;
2929
import org.apache.commons.logging.LogFactory;
3030
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
31+
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
3132
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
3233
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult;
3334
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
@@ -671,11 +672,31 @@ public void clear() {
671672
/**
672673
* Recompress this column group into a new column group of the given type.
673674
*
674-
* @param ct The compressionType that the column group should morph into
675+
* @param ct The compressionType that the column group should morph into
676+
* @param nRow The number of rows in this columngroup.
675677
* @return A new column group
676678
*/
677-
public AColGroup morph(CompressionType ct) {
678-
throw new NotImplementedException();
679+
public AColGroup morph(CompressionType ct, int nRow) {
680+
if(ct == getCompType())
681+
return this;
682+
else if(ct == CompressionType.DDCFOR)
683+
return this; // it does not make sense to change to FOR.
684+
else if(ct == CompressionType.UNCOMPRESSED) {
685+
AColGroup cgMoved = this.copyAndSet(ColIndexFactory.create(_colIndexes.size()));
686+
final long nnz = getNumberNonZeros(nRow);
687+
MatrixBlock newDict = new MatrixBlock(nRow, _colIndexes.size(), nnz);
688+
newDict.allocateBlock();
689+
if(newDict.isInSparseFormat())
690+
cgMoved.decompressToSparseBlock(newDict.getSparseBlock(), 0, nRow);
691+
else
692+
cgMoved.decompressToDenseBlock(newDict.getDenseBlock(), 0, nRow);
693+
newDict.setNonZeros(nnz);
694+
AColGroup cgUC = ColGroupUncompressed.create(newDict);
695+
return cgUC.copyAndSet(_colIndexes);
696+
}
697+
else {
698+
throw new NotImplementedException("Morphing from : " + getCompType() + " to " + ct + " is not implemented");
699+
}
679700
}
680701

681702
/**
@@ -690,10 +711,11 @@ public AColGroup morph(CompressionType ct) {
690711
* Combine this column group with another
691712
*
692713
* @param other The other column group to combine with.
714+
* @param nRow The number of rows in both column groups.
693715
* @return A combined representation as a column group.
694716
*/
695-
public AColGroup combine(AColGroup other) {
696-
return CLALibCombineGroups.combine(this, other);
717+
public AColGroup combine(AColGroup other, int nRow) {
718+
return CLALibCombineGroups.combine(this, other, nRow);
697719
}
698720

699721
/**
@@ -745,6 +767,13 @@ public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlo
745767
denseSelection(selection, points, ret, rl, ru);
746768
}
747769

770+
/**
771+
* Get an approximate sparsity of this column group
772+
*
773+
* @return the approximate sparsity of this columngroup
774+
*/
775+
public abstract double getSparsity();
776+
748777
/**
749778
* Sparse selection (left matrix multiply)
750779
*

src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,5 +324,9 @@ public AColGroup reduceCols(){
324324
return null;
325325
return copyAndSet(outCols, newDict);
326326
}
327-
327+
328+
@Override
329+
public double getSparsity() {
330+
return _dict.getSparsity();
331+
}
328332
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,11 @@ public AMapToData getMapToData() {
648648
return MapToFactory.create(0, 0);
649649
}
650650

651+
@Override
652+
public double getSparsity() {
653+
return 1.0;
654+
}
655+
651656
@Override
652657
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
653658
throw new NotImplementedException();

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,11 @@ public AColGroup reduceCols() {
410410
return null;
411411
}
412412

413+
@Override
414+
public double getSparsity() {
415+
return 0.0;
416+
}
417+
413418
@Override
414419
protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
415420
throw new NotImplementedException();

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,11 @@ public AColGroup reduceCols() {
709709
throw new NotImplementedException();
710710
}
711711

712+
@Override
713+
public double getSparsity() {
714+
return 1.0;
715+
}
716+
712717
@Override
713718
public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
714719
throw new NotImplementedException();

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,11 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
907907
ret.set(r, c, _data.get(r, reordering[c]));
908908
return create(newColIndex, ret, false);
909909
}
910+
911+
@Override
912+
public double getSparsity() {
913+
return _data.getSparsity();
914+
}
910915

911916
@Override
912917
public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,93 @@ public static void correctNan(double[] res, IColIndex colIndexes) {
8787
}
8888
}
8989

90+
@Override
91+
public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thisCols, IColIndex aggregateColumns,
92+
int nColRight) {
93+
if(aggregateColumns.size() < nColRight)
94+
return rightMMPreAggSparseSelectedCols(numVals, b, thisCols, aggregateColumns);
95+
else
96+
return rightMMPreAggSparseAllColsRight(numVals, b, thisCols, nColRight);
97+
}
98+
99+
protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols,
100+
IColIndex aggregateColumns) {
101+
102+
final int thisColsSize = thisCols.size();
103+
final int aggColSize = aggregateColumns.size();
104+
final double[] ret = new double[numVals * aggColSize];
105+
106+
for(int h = 0; h < thisColsSize; h++) {
107+
// chose row in right side matrix via column index of the dictionary
108+
final int colIdx = thisCols.get(h);
109+
if(b.isEmpty(colIdx))
110+
continue;
111+
112+
// extract the row values on the right side.
113+
final double[] sValues = b.values(colIdx);
114+
final int[] sIndexes = b.indexes(colIdx);
115+
final int sPos = b.pos(colIdx);
116+
final int sEnd = b.size(colIdx) + sPos;
117+
118+
for(int j = 0; j < numVals; j++) { // rows left
119+
final int offOut = j * aggColSize;
120+
final double v = getValue(j, h, thisColsSize);
121+
sparseAddSelected(sPos, sEnd, aggColSize, aggregateColumns, sIndexes, sValues, ret, offOut, v);
122+
}
123+
124+
}
125+
return Dictionary.create(ret);
126+
}
127+
128+
private final void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex aggregateColumns, int[] sIndexes,
129+
double[] sValues, double[] ret, int offOut, double v) {
130+
131+
int retIdx = 0;
132+
for(int i = sPos; i < sEnd; i++) {
133+
// skip through the retIdx.
134+
while(retIdx < aggColSize && aggregateColumns.get(retIdx) < sIndexes[i])
135+
retIdx++;
136+
if(retIdx == aggColSize)
137+
break;
138+
ret[offOut + retIdx] += v * sValues[i];
139+
}
140+
retIdx = 0;
141+
}
142+
143+
protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols,
144+
int nColRight) {
145+
final int thisColsSize = thisCols.size();
146+
final double[] ret = new double[numVals * nColRight];
147+
148+
for(int h = 0; h < thisColsSize; h++) { // common dim
149+
// chose row in right side matrix via column index of the dictionary
150+
final int colIdx = thisCols.get(h);
151+
if(b.isEmpty(colIdx))
152+
continue;
153+
154+
// extract the row values on the right side.
155+
final double[] sValues = b.values(colIdx);
156+
final int[] sIndexes = b.indexes(colIdx);
157+
final int sPos = b.pos(colIdx);
158+
final int sEnd = b.size(colIdx) + sPos;
159+
160+
for(int i = 0; i < numVals; i++) { // rows left
161+
final int offOut = i * nColRight;
162+
final double v = getValue(i, h, thisColsSize);
163+
SparseAdd(sPos, sEnd, ret, offOut, sIndexes, sValues, v);
164+
}
165+
}
166+
return Dictionary.create(ret);
167+
}
168+
169+
private final void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) {
170+
if(v != 0) {
171+
for(int k = sPos; k < sEnd; k++) { // cols right with value
172+
ret[offOut + sIdx[k]] += v * sVals[k];
173+
}
174+
}
175+
}
176+
90177
@Override
91178
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
92179
for(int i = 0; i < nCol; i++)
@@ -101,4 +188,12 @@ public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, IColIndex col
101188
dv[off + columns.get(i)] += getValue(idx, i, nCol);
102189
}
103190

191+
@Override
192+
public double[] getRow(int i, int nCol) {
193+
double[] ret = new double[nCol];
194+
for(int c = 0; c < nCol; c++) {
195+
ret[c] = getValue(i, c, nCol);
196+
}
197+
return ret;
198+
}
104199
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ else if(row > col) // swap because in lower triangle
5656
/**
5757
* Matrix multiply with scaling (left side transposed)
5858
*
59-
* @param left Left side dictionary
60-
* @param right Right side dictionary
59+
* @param left Left side dictionary that is not physically transposed but should be treated if it is.
60+
* @param right Right side dictionary that is not transposed and should be used as is.
6161
* @param leftRows Left side row offsets
6262
* @param rightColumns Right side column offsets
63-
* @param result The result matrix
63+
* @param result The result matrix, normal allocation.
6464
* @param counts The scaling factors
6565
*/
6666
public static void MMDictsWithScaling(IDictionary left, IDictionary right, IColIndex leftRows,
@@ -221,7 +221,6 @@ protected static void MMDictsScalingDenseDense(double[] left, double[] right, IC
221221
final int commonDim = Math.min(left.length / leftSide, right.length / rightSide);
222222
final int resCols = result.getNumColumns();
223223
final double[] resV = result.getDenseBlockValues();
224-
225224
for(int k = 0; k < commonDim; k++) {
226225
final int offL = k * leftSide;
227226
final int offR = k * rightSide;
@@ -305,8 +304,8 @@ protected static void MMDictsDenseSparse(double[] left, SparseBlock right, IColI
305304
}
306305
}
307306

308-
protected static void MMDictsScalingDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight,
309-
MatrixBlock result, int[] scaling) {
307+
protected static void MMDictsScalingDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft,
308+
IColIndex colsRight, MatrixBlock result, int[] scaling) {
310309
final double[] resV = result.getDenseBlockValues();
311310
final int leftSize = rowsLeft.size();
312311
final int commonDim = Math.min(left.length / leftSize, right.numRows());
@@ -538,19 +537,27 @@ else if(loc > 0)
538537

539538
protected static void MMToUpperTriangleDenseDenseAllUpperTriangle(double[] left, double[] right, IColIndex rowsLeft,
540539
IColIndex colsRight, MatrixBlock result) {
541-
final int commonDim = Math.min(left.length / rowsLeft.size(), right.length / colsRight.size());
540+
final int lSize = rowsLeft.size();
541+
final int rSize = colsRight.size();
542+
final int commonDim = Math.min(left.length / lSize, right.length / rSize);
542543
final int resCols = result.getNumColumns();
543544
final double[] resV = result.getDenseBlockValues();
545+
for(int i = 0; i < lSize; i++) {
546+
MMToUpperTriangleDenseDenseAllUpperTriangleRow(left, right, rowsLeft.get(i), colsRight, commonDim, lSize,
547+
rSize, i, resV, resCols);
548+
}
549+
}
550+
551+
protected static void MMToUpperTriangleDenseDenseAllUpperTriangleRow(final double[] left, final double[] right,
552+
final int rowOut, final IColIndex colsRight, final int commonDim, final int lSize, final int rSize, final int i,
553+
final double[] resV, final int resCols) {
544554
for(int k = 0; k < commonDim; k++) {
545-
final int offL = k * rowsLeft.size();
546-
final int offR = k * colsRight.size();
547-
for(int i = 0; i < rowsLeft.size(); i++) {
548-
final int rowOut = rowsLeft.get(i);
549-
final double vl = left[offL + i];
550-
if(vl != 0) {
551-
for(int j = 0; j < colsRight.size(); j++)
552-
resV[colsRight.get(j) * resCols + rowOut] += vl * right[offR + j];
553-
}
555+
final int offL = k * lSize;
556+
final double vl = left[offL + i];
557+
if(vl != 0) {
558+
final int offR = k * rSize;
559+
for(int j = 0; j < rSize; j++)
560+
resV[colsRight.get(j) * resCols + rowOut] += vl * right[offR + j];
554561
}
555562
}
556563
}

0 commit comments

Comments
 (0)