Skip to content

Commit 54c8696

Browse files
committed
[MINOR] Compressed tests
This commit follows up on the rexpand instruction to improve the test coverage and fix a few bugs in CLA. Closes #2214
1 parent b43fa11 commit 54c8696

25 files changed

+1042
-203
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ public int getNumValues() {
5555
}
5656

5757
/**
58-
* Returns the counts of values inside the dictionary. If already calculated it will return the previous counts.
59-
* This produce an overhead in cases where the count is calculated, but the overhead will be limited to number of
60-
* distinct tuples in the dictionary.
58+
* Returns the counts of values inside the dictionary. If already calculated it will return the previous counts. This
59+
* produces an overhead in cases where the count is calculated, but the overhead will be limited to the number of distinct
60+
* tuples in the dictionary.
6161
*
6262
* The returned counts always contain the number of zero tuples as well if there are some contained, even if they
6363
* are not materialized.
@@ -195,16 +195,16 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) {
195195

196196
@Override
197197
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
198-
try {
199-
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
200-
if(d == null)
201-
return ColGroupEmpty.create(max);
202-
else
203-
return copyAndSet(ColIndexFactory.create(max), d);
204-
}
205-
catch(DMLCompressionException e) {
198+
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
199+
if(d == null) {
200+
if(max <= 0)
201+
return null;
206202
return ColGroupEmpty.create(max);
207203
}
204+
else {
205+
IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
206+
return copyAndSet(outCols, d);
207+
}
208208
}
209209

210210
@Override

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,11 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) {
527527
@Override
528528
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
529529
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
530-
if(d == null)
530+
if(d == null){
531+
if(max <= 0)
532+
return null;
531533
return ColGroupEmpty.create(max);
534+
}
532535
else
533536
return create(max, d);
534537
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import java.util.List;
2727

2828
import org.apache.commons.lang3.NotImplementedException;
29-
import org.apache.sysds.runtime.DMLRuntimeException;
3029
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P;
3130
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
3231
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
@@ -392,33 +391,15 @@ public AColGroup extractCommon(double[] constV) {
392391
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
393392
final int def = (int) _reference[0];
394393
IDictionary d = _dict.rexpandColsWithReference(max, ignore, cast, def);
395-
396394
if(d == null) {
397-
if(def <= 0 || def > max)
398-
return ColGroupEmpty.create(max);
399-
else {
400-
double[] retDef = new double[max];
401-
retDef[def - 1] = 1;
402-
return ColGroupConst.create(retDef);
403-
}
395+
if(max <= 0)
396+
return null;
397+
return ColGroupEmpty.create(max);
404398
}
405399
else {
406-
IColIndex outCols = ColIndexFactory.create(max);
407-
if(def <= 0) {
408-
if(ignore)
409-
return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
410-
else
411-
throw new DMLRuntimeException("Invalid content of zero in rexpand");
412-
}
413-
else if(def > max)
414-
return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
415-
else {
416-
double[] retDef = new double[max];
417-
retDef[def - 1] = 1;
418-
return ColGroupDDCFOR.create(outCols, d, _data, getCachedCounts(), retDef);
419-
}
400+
IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
401+
return ColGroupDDC.create(outCols, d, _data, getCachedCounts());
420402
}
421-
422403
}
423404

424405
@Override

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,15 +500,24 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) {
500500
@Override
501501
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
502502
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
503-
return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), (int) _defaultTuple[0]);
503+
return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), (int) _defaultTuple[0],
504+
_dict.getNumberOfValues(1));
504505
}
505506

506507
protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows, IDictionary d,
507-
AOffset indexes, AMapToData data, int[] counts, int def) {
508+
AOffset indexes, AMapToData data, int[] counts, int def, int nVal) {
508509

509510
if(d == null) {
510-
if(def <= 0 || def > max)
511+
if(def <= 0){
512+
if(max > 0)
513+
return ColGroupEmpty.create(max);
514+
else
515+
return null;
516+
}
517+
else if(def > max && max > 0)
511518
return ColGroupEmpty.create(max);
519+
else if(max <= 0)
520+
return null;
512521
else {
513522
double[] retDef = new double[max];
514523
retDef[def - 1] = 1;
@@ -517,7 +526,7 @@ protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, in
517526
}
518527
}
519528
else {
520-
final IColIndex outCols = ColIndexFactory.create(max);
529+
final IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(nVal));
521530
if(def <= 0) {
522531
if(ignore)
523532
return ColGroupSDCZeros.create(outCols, nRows, d, indexes, data, counts);

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ public AColGroup extractCommon(double[] constV) {
427427
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
428428
IDictionary d = _dict.rexpandColsWithReference(max, ignore, cast, (int) _reference[0]);
429429
return ColGroupSDC.rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(),
430-
(int) _reference[0]);
430+
(int) _reference[0], _dict.getNumberOfValues(1));
431431
}
432432

433433
@Override

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public static AColGroup create(IColIndex colIndexes, int numRows, IDictionary di
8585
if(offsets instanceof OffsetEmpty)
8686
return ColGroupConst.create(colIndexes, defaultTuple);
8787
final boolean allZero = ColGroupUtils.allZero(defaultTuple);
88-
if(dict == null && allZero)
88+
if(dict == null && allZero)
8989
return new ColGroupEmpty(colIndexes);
9090
else if(dict == null && offsets.getSize() * 2 > numRows + 2) {
9191
AOffset rev = offsets.reverse(numRows);
@@ -469,27 +469,36 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
469469
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
470470
final int def = (int) _defaultTuple[0];
471471
if(d == null) {
472-
if(def <= 0 || def > max)
472+
if(def <= 0){
473+
if(max > 0)
474+
return ColGroupEmpty.create(max);
475+
else
476+
return null;
477+
}
478+
else if(def > max && max > 0)
473479
return ColGroupEmpty.create(max);
480+
else if(max <= 0)
481+
return null;
474482
else {
475483
double[] retDef = new double[max];
476484
retDef[((int) _defaultTuple[0]) - 1] = 1;
477485
return ColGroupSDCSingle.create(ColIndexFactory.create(max), nRows, null, retDef, _indexes, null);
478486
}
479487
}
480488
else {
489+
final IColIndex outCols = ColIndexFactory.create(d.getNumberOfColumns(_dict.getNumberOfValues(1)));
481490
if(def <= 0) {
482491
if(ignore)
483-
return ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes, getCachedCounts());
492+
return ColGroupSDCSingleZeros.create(outCols, nRows, d, _indexes, getCachedCounts());
484493
else
485494
throw new DMLRuntimeException("Invalid content of zero in rexpand");
486495
}
487496
else if(def > max)
488-
return ColGroupSDCSingleZeros.create(ColIndexFactory.create(max), nRows, d, _indexes, getCachedCounts());
497+
return ColGroupSDCSingleZeros.create(outCols, nRows, d, _indexes, getCachedCounts());
489498
else {
490499
double[] retDef = new double[max];
491500
retDef[((int) _defaultTuple[0]) - 1] = 1;
492-
return ColGroupSDCSingle.create(ColIndexFactory.create(max), nRows, d, retDef, _indexes, getCachedCounts());
501+
return ColGroupSDCSingle.create(outCols, nRows, d, retDef, _indexes, getCachedCounts());
493502
}
494503
}
495504
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ public int getNumberOfValues(int ncol) {
9797
return _values.length / ncol;
9898
}
9999

100+
@Override
101+
public int getNumberOfColumns(int nrow){
102+
return _values.length / nrow;
103+
}
104+
100105
@Override
101106
public String getString(int colIndexes) {
102107
throw new NotImplementedException();

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
4242
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
4343
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
44-
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
44+
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
4545
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
4646
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
4747
import org.apache.sysds.utils.MemoryEstimates;
@@ -388,6 +388,11 @@ public int getNumberOfValues(int nCol) {
388388
return _values.length / nCol;
389389
}
390390

391+
@Override
392+
public int getNumberOfColumns(int nrow) {
393+
return _values.length / nrow;
394+
}
395+
391396
@Override
392397
public double[] sumAllRowsToDouble(int nrColumns) {
393398
if(nrColumns == 1)
@@ -1120,8 +1125,11 @@ public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cas
11201125
MatrixBlockDictionary m = getMBDict(1);
11211126
if(m == null)
11221127
return null;
1123-
IDictionary a = m.applyScalarOp(new LeftScalarOperator(Plus.getPlusFnObject(), reference));
1124-
return a == null ? null : a.rexpandCols(max, ignore, cast, 1);
1128+
IDictionary a = m.applyScalarOp(new RightScalarOperator(Plus.getPlusFnObject(), reference));
1129+
if(a == null)
1130+
return null; // second ending
1131+
a = a.rexpandCols(max, ignore, cast, 1);
1132+
return a;
11251133
}
11261134

11271135
@Override

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,14 @@ public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIn
327327
*/
328328
public int getNumberOfValues(int ncol);
329329

330+
/**
331+
* Get the number of columns in this dictionary, provided you know the number of values, or rows.
332+
*
333+
* @param nrow The number of rows/values known inside this dictionary
334+
* @return The number of columns
335+
*/
336+
public int getNumberOfColumns(int nrow);
337+
330338
/**
331339
* Method used as a pre-aggregate of each tuple in the dictionary, to single double values.
332340
*

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,13 @@ public int getNumberOfValues(int ncol) {
194194
return nRowCol + (withEmpty ? 1 : 0);
195195
}
196196

197+
@Override
198+
public int getNumberOfColumns(int nrow) {
199+
if(nrow != (nRowCol + (withEmpty ? 1 : 0)))
200+
throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns");
201+
return nRowCol;
202+
}
203+
197204
@Override
198205
public double[] sumAllRowsToDouble(int nrColumns) {
199206
if(withEmpty) {

0 commit comments

Comments
 (0)