Skip to content

Commit 8f4dba1

Browse files
committed
[SYSTEMDS-3801] Fix missing method implementations in ColGroupSDCZeros
The previous master version broke the AWARE experiment for the kmeans+ algorithm. This patch fixes that and adds the missing method implementations for DenseBlocks in ColGroupSDCZeros. With these changes, the runtime of the kmeans+ algorithm on the US Census dataset also decreased from 40s to 32s. Closes #2149.
1 parent 53cba12 commit 8f4dba1

File tree

6 files changed

+74
-12
lines changed

6 files changed

+74
-12
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re
683683
}
684684
else {
685685
while(c < points.length && points[c].o == of) {
686-
_dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
686+
_dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
687687
c++;
688688
}
689689
of = it.next();
@@ -696,7 +696,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re
696696
}
697697

698698
while(of == last && c < points.length && points[c].o == of) {
699-
_dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
699+
_dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
700700
c++;
701701
}
702702

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret,
836836

837837
while(of < last && c < points.length) {
838838
if(points[c].o == of) {
839-
c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
839+
c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
840840
of = it.next();
841841
}
842842
else if(points[c].o < of)
@@ -848,18 +848,46 @@ else if(points[c].o < of)
848848
while(c < points.length && points[c].o < last)
849849
c++;
850850

851-
c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
851+
c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
852852

853853
}
854854

855855
@Override
856856
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
857-
throw new NotImplementedException();
857+
final DenseBlock dr = ret.getDenseBlock();
858+
final int nCol = _colIndexes.size();
859+
final AIterator it = _indexes.getIterator();
860+
final int last = _indexes.getOffsetToLast();
861+
int c = 0;
862+
int of = it.value();
863+
864+
while(of < last && c < points.length) {
865+
if(points[c].o == of) {
866+
c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex()));
867+
of = it.next();
868+
}
869+
else if(points[c].o < of)
870+
c++;
871+
else
872+
of = it.next();
873+
}
874+
// advance the c pointer until it points at an offset >= last, or all points are consumed.
875+
while(c < points.length && points[c].o < last)
876+
c++;
877+
c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex()));
878+
}
879+
880+
private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) {
881+
while(c < points.length && points[c].o == of) {
882+
_dict.putSparse(sr, did, points[c].r, nCol, _colIndexes);
883+
c++;
884+
}
885+
return c;
858886
}
859887

860-
private int processRow(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) {
888+
private int processRowDense(P[] points, final DenseBlock dr, final int nCol, int c, int of, final int did) {
861889
while(c < points.length && points[c].o == of) {
862-
_dict.put(sr, did, points[c].r, nCol, _colIndexes);
890+
_dict.putDense(dr, did, points[c].r, nCol, _colIndexes);
863891
c++;
864892
}
865893
return c;

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.io.Serializable;
2323

2424
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
25+
import org.apache.sysds.runtime.data.DenseBlock;
2526
import org.apache.sysds.runtime.data.SparseBlock;
2627
import org.apache.sysds.runtime.functionobjects.ValueFunction;
2728
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
@@ -87,8 +88,17 @@ public static void correctNan(double[] res, IColIndex colIndexes) {
8788
}
8889

8990
@Override
90-
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
91+
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
9192
for(int i = 0; i < nCol; i++)
9293
sb.append(rowOut, columns.get(i), getValue(idx, i, nCol));
9394
}
95+
96+
@Override
97+
public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, IColIndex columns) {
98+
double[] dv = dr.values(rowOut);
99+
int off = dr.pos(rowOut);
100+
for(int i = 0; i < nCol; i++)
101+
dv[off + columns.get(i)] += getValue(idx, i, nCol);
102+
}
103+
94104
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.commons.logging.Log;
2626
import org.apache.commons.logging.LogFactory;
2727
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
28+
import org.apache.sysds.runtime.data.DenseBlock;
2829
import org.apache.sysds.runtime.data.SparseBlock;
2930
import org.apache.sysds.runtime.functionobjects.Builtin;
3031
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -989,6 +990,18 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef
989990
* @param nCol The number of columns in the dictionary
990991
* @param columns The columns to output into.
991992
*/
992-
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns);
993+
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns);
994+
995+
/**
996+
* Put the row specified into the dense block, by adding the dictionary values to the target cells.
997+
*
998+
* @param db The dense block to put into
999+
* @param idx The dictionary index to put in.
1000+
* @param rowOut  The row in the dense block to put it into
1001+
* @param nCol The number of columns in the dictionary
1002+
* @param columns The columns to output into.
1003+
*/
1004+
public void putDense(DenseBlock db, int idx, int rowOut, int nCol, IColIndex columns);
1005+
9931006

9941007
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.io.Serializable;
2626

2727
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
28+
import org.apache.sysds.runtime.data.DenseBlock;
2829
import org.apache.sysds.runtime.data.SparseBlock;
2930
import org.apache.sysds.runtime.functionobjects.Builtin;
3031
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -526,7 +527,12 @@ public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex
526527
}
527528

528529
@Override
529-
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
530+
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
531+
throw new RuntimeException(errMessage);
532+
}
533+
534+
@Override
535+
public void putDense(DenseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
530536
throw new RuntimeException(errMessage);
531537
}
532538
}

src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,13 @@ public void MMDictScalingSparse() {
490490
}
491491

492492
@Test(expected = Exception.class)
493-
public void put() {
494-
d.put(null, 1, 1, 1, null);
493+
public void putDense() {
494+
d.putDense(null, 1, 1, 1, null);
495+
}
496+
497+
@Test(expected = Exception.class)
498+
public void putSparse() {
499+
d.putSparse(null, 1, 1, 1, null);
495500
}
496501

497502
@Test

0 commit comments

Comments
 (0)