Skip to content

Commit 6c81789

Browse files
committed
more tests
1 parent 545b2b6 commit 6c81789

File tree

7 files changed

+173
-45
lines changed

7 files changed

+173
-45
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public final CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] co
5454

5555
@Override
5656
public final boolean equals(Object o) {
57-
if(o != null && o instanceof IDictionary)
57+
if(o != null && o instanceof IDictionary)
5858
return equals((IDictionary) o);
5959
return false;
6060
}
@@ -279,8 +279,8 @@ public long getNumberNonZerosWithReference(int[] counts, double[] reference, int
279279

280280
@Override
281281
public boolean containsValueWithReference(double pattern, double[] reference) {
282-
if(Double.isNaN(pattern)){
283-
for(int i = 0 ; i < reference.length; i++)
282+
if(Double.isNaN(pattern)) {
283+
for(int i = 0; i < reference.length; i++)
284284
if(Double.isNaN(reference[i]))
285285
return true;
286286
return containsValue(pattern);

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,9 @@ public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) {
465465
final double[] ret = new double[numVals + 1];
466466
for(int k = 0; k < numVals; k++)
467467
ret[k] = prodRow(k, nCol);
468-
ret[ret.length - 1] = defaultTuple[0];
468+
ret[numVals] = defaultTuple[0];
469469
for(int i = 1; i < nCol; i++)
470-
ret[ret.length - 1] *= defaultTuple[i];
470+
ret[numVals] *= defaultTuple[i];
471471
return ret;
472472
}
473473

@@ -522,9 +522,10 @@ private double sumRowSq(int k, int nrColumns) {
522522

523523
private double prodRow(int k, int nrColumns) {
524524
final int valOff = k * nrColumns;
525+
final int end = valOff + nrColumns;
525526
double res = _values[valOff];
526-
for(int i = 1; i < nrColumns; i++)
527-
res *= _values[valOff + i];
527+
for(int i = valOff + 1; i < end && res != 0; i++) // early abort on zero
528+
res *= _values[i];
528529
return res;
529530
}
530531

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ public DictType getDictType() {
151151

152152
@Override
153153
public double[] sumAllRowsToDouble(int nrColumns) {
154-
double[] ret = new double[nRowCol];
154+
double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)];
155155
Arrays.fill(ret, l, u, 1.0);
156156
return ret;
157157
}
@@ -183,19 +183,29 @@ public double[] sumAllRowsToDoubleWithReference(double[] reference) {
183183

184184
@Override
185185
public double[] sumAllRowsToDoubleSq(int nrColumns) {
186-
double[] ret = new double[nRowCol];
186+
double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)];
187187
Arrays.fill(ret, l, u, 1);
188188
return ret;
189189
}
190190

191191
@Override
192192
public double[] productAllRowsToDouble(int nCol) {
193-
return new double[nRowCol];
193+
double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)];
194+
if(u - l - 1 == 0)
195+
ret[l] = 1;
196+
return ret;
194197
}
195198

196199
@Override
197200
public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) {
198-
return new double[nRowCol];
201+
int nVal = nRowCol + (withEmpty ? 1 : 0);
202+
double[] ret = new double[nVal + 1];
203+
if(u - l - 1 == 0)
204+
ret[l] = 1;
205+
ret[nVal] = defaultTuple[0];
206+
for(int i = 1; i < defaultTuple.length; i++)
207+
ret[nVal] *= defaultTuple[i];
208+
return ret;
199209
}
200210

201211
@Override

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
import org.apache.sysds.runtime.data.SparseBlockCSR;
4242
import org.apache.sysds.runtime.data.SparseBlockFactory;
4343
import org.apache.sysds.runtime.data.SparseBlockMCSR;
44+
import org.apache.sysds.runtime.data.SparseRow;
45+
import org.apache.sysds.runtime.data.SparseRowScalar;
4446
import org.apache.sysds.runtime.functionobjects.Builtin;
4547
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
4648
import org.apache.sysds.runtime.functionobjects.Divide;
@@ -1143,17 +1145,31 @@ public double[] productAllRowsToDouble(int nCol) {
11431145
}
11441146

11451147
private final void productAllRowsToDouble(double[] ret, int nCol) {
1148+
final int nRow = _data.getNumRows();
1149+
11461150
if(_data.isInSparseFormat()) {
11471151
SparseBlock sb = _data.getSparseBlock();
1148-
for(int i = 0; i < _data.getNumRows(); i++) {
1149-
if(!sb.isEmpty(i) && sb.size(i) == nCol) {
1152+
for(int i = 0; i < nRow; i++) {
1153+
if(!sb.isEmpty(i)) {
11501154
// if not equal to nCol ... skip
11511155
final int apos = sb.pos(i);
11521156
final int alen = sb.size(i) + apos;
1157+
final int[] aix = sb.indexes(i);
11531158
final double[] avals = sb.values(i);
11541159
ret[i] = 1;
1155-
for(int j = apos; j < alen; j++) {
1160+
int pj = 0;
1161+
// many extra cases to handle NaN...
1162+
for(int j = apos; j < alen && !Double.isNaN(ret[i]); j++) {
1163+
if(aix[j] - pj >= 1) {
1164+
ret[i] = 0;
1165+
break;
1166+
}
11561167
ret[i] *= avals[j];
1168+
pj = aix[j];
1169+
}
1170+
1171+
if(!Double.isNaN(ret[i]) && sb.size(i) != nCol) {
1172+
ret[i] = 0;
11571173
}
11581174
}
11591175
else
@@ -1163,9 +1179,9 @@ private final void productAllRowsToDouble(double[] ret, int nCol) {
11631179
else {
11641180
double[] values = _data.getDenseBlockValues();
11651181
int off = 0;
1166-
for(int k = 0; k < _data.getNumRows(); k++) {
1182+
for(int k = 0; k < nRow; k++) {
11671183
ret[k] = 1;
1168-
for(int j = 0; j < _data.getNumColumns(); j++) {
1184+
for(int j = 0; j < nCol && ret[k] != 0; j++) { // early abort on zero
11691185
final double v = values[off++];
11701186
ret[k] *= v;
11711187
}
@@ -1175,11 +1191,12 @@ private final void productAllRowsToDouble(double[] ret, int nCol) {
11751191

11761192
@Override
11771193
public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) {
1178-
double[] ret = new double[_data.getNumRows() + 1];
1194+
final int nRow = _data.getNumRows();
1195+
double[] ret = new double[nRow + 1];
11791196
productAllRowsToDouble(ret, defaultTuple.length);
1180-
ret[_data.getNumRows()] = defaultTuple[0];
1197+
ret[nRow] = defaultTuple[0];
11811198
for(int j = 1; j < defaultTuple.length; j++)
1182-
ret[_data.getNumRows()] *= defaultTuple[j];
1199+
ret[nRow] *= defaultTuple[j];
11831200

11841201
return ret;
11851202
}
@@ -1289,7 +1306,7 @@ public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) {
12891306
final int[] aix = sb.indexes(i);
12901307
final double[] avals = sb.values(i);
12911308
for(int j = apos; j < alen; j++) {
1292-
c[colIndexes.get(aix[j])] += count * avals[j] * avals[j];
1309+
c[colIndexes.get(aix[j])] += avals[j] * avals[j] * count;
12931310
}
12941311
}
12951312
}
@@ -1847,12 +1864,12 @@ public IDictionary scaleTuples(int[] scaling, int nCol) {
18471864
if(!sbThis.isEmpty(i)) {
18481865
sbRet.set(i, sbThis.get(i), true);
18491866

1850-
final int count = scaling[i];
1867+
final int sc = scaling[i];
18511868
final int apos = sbRet.pos(i);
18521869
final int alen = sbRet.size(i) + apos;
18531870
final double[] avals = sbRet.values(i);
18541871
for(int j = apos; j < alen; j++)
1855-
avals[j] = count * avals[j];
1872+
avals[j] = sc * avals[j];
18561873
}
18571874
}
18581875
retBlock.setNonZeros(_data.getNonZeros());
@@ -2594,30 +2611,25 @@ public IDictionary reorder(int[] reorder) {
25942611

25952612
@Override
25962613
public IDictionary append(double[] row) {
2597-
if(_data.isEmpty()) {
2598-
throw new NotImplementedException();
2599-
}
2600-
else if(_data.isInSparseFormat()) {
2614+
if(_data.isInSparseFormat()) {
26012615
final int nRow = _data.getNumRows();
2602-
if(_data.getSparseBlock() instanceof SparseBlockMCSR) {
2603-
MatrixBlock mb = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), true);
2604-
mb.allocateBlock();
2605-
SparseBlock sb = mb.getSparseBlock();
2606-
SparseBlockMCSR s = (SparseBlockMCSR) _data.getSparseBlock();
2607-
2608-
for(int i = 0; i < _data.getNumRows(); i++)
2609-
sb.set(i, s.get(i), false);
2610-
2611-
for(int i = 0; i < row.length; i++)
2612-
sb.set(nRow, i, row[i]);
2613-
2614-
mb.examSparsity();
2615-
return new MatrixBlockDictionary(mb);
2616-
2617-
}
2618-
else {
2619-
throw new NotImplementedException("Not implemented append for CSR");
2620-
}
2616+
final int nCol = _data.getNumColumns();
2617+
SparseRow sr = null;
2618+
for(int i = 0; i < row.length; i++) {
2619+
if(row[i] != 0) {
2620+
if(sr == null)
2621+
sr = new SparseRowScalar(i, row[i]);
2622+
else
2623+
sr = sr.append(i, row[i]);
2624+
}
2625+
}
2626+
MatrixBlock mb = new MatrixBlock(_data.getNumRows() + 1, _data.getNumColumns(), true);
2627+
mb.allocateBlock();
2628+
SparseBlock sb = mb.getSparseBlock();
2629+
mb.copy(0, nRow, 0, nCol, _data, false);
2630+
sb.set(nRow, sr, false);
2631+
mb.examSparsity();
2632+
return new MatrixBlockDictionary(mb);
26212633

26222634
}
26232635
else {

src/test/java/org/apache/sysds/test/TestUtils.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2137,6 +2137,17 @@ public static double[] generateTestVector(int cols, double min, double max, doub
21372137
return vector;
21382138
}
21392139

2140+
public static int[] generateTestIntVector(int cols, int min, int max, double sparsity, long seed) {
2141+
int[] vector = new int[cols];
2142+
Random random = (seed == -1) ? TestUtils.random : new Random(seed);
2143+
for(int j = 0; j < cols; j++) {
2144+
if(random.nextDouble() > sparsity)
2145+
continue;
2146+
vector[j] = (random.nextInt(max - min) + min);
2147+
}
2148+
return vector;
2149+
}
2150+
21402151
/**
21412152
*
21422153
* Generates a test matrix with the specified parameters as a MatrixBlock.

src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,4 +642,10 @@ public void createDictionary() {
642642
assertThrows(RuntimeException.class, () -> IdentityDictionarySlice.create(1, true, -13, 0));
643643
assertThrows(RuntimeException.class, () -> IdentityDictionarySlice.create(10, true, 4, 11));
644644
}
645+
646+
647+
@Test
648+
public void notEqualsObject(){
649+
assertNotEquals(Dictionary.create(new double[]{1.1,2.2,3.3}), new Object());
650+
}
645651
}

src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,4 +1771,92 @@ public void scalarOp(ScalarOperator op) {
17711771
bb = b.applyScalarOpWithReference(op, ref1, ref2);
17721772
compare(aa, bb, nCol);
17731773
}
1774+
1775+
@Test
1776+
public void scaleTuples() {
1777+
IDictionary aa;
1778+
IDictionary bb;
1779+
1780+
int[] scale = TestUtils.generateTestIntVector(nRow, 1, 10, 1, 3213);
1781+
aa = a.scaleTuples(scale, nCol);
1782+
bb = b.scaleTuples(scale, nCol);
1783+
compare(aa, bb, nCol);
1784+
}
1785+
1786+
// productAllRowsToDouble
1787+
@Test
1788+
public void productRows() {
1789+
double[] aa;
1790+
double[] bb;
1791+
String err = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName();
1792+
// int[] scale = TestUtils.generateTestIntVector(nRow, 1, 10, 1, 3213);
1793+
aa = a.productAllRowsToDouble(nCol);
1794+
bb = b.productAllRowsToDouble(nCol);
1795+
assertArrayEquals(err, aa, bb, 0.0000001);
1796+
1797+
double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1, 3216245);
1798+
aa = a.productAllRowsToDoubleWithDefault(def);
1799+
bb = b.productAllRowsToDoubleWithDefault(def);
1800+
assertArrayEquals(err, aa, bb, 0.0000001);
1801+
1802+
double[] ref = TestUtils.generateTestVector(nCol, 1, 10, 1, 3216245);
1803+
aa = a.productAllRowsToDoubleWithReference(ref);
1804+
bb = b.productAllRowsToDoubleWithReference(ref);
1805+
assertArrayEquals(err, aa, bb, 0.0000001);
1806+
}
1807+
1808+
@Test
1809+
public void appendRow() {
1810+
double[] r = TestUtils.generateTestVector(nCol, 1, 10, 0.9, 2222);
1811+
IDictionary aa = a.append(r);
1812+
IDictionary bb = b.append(r);
1813+
1814+
compare(aa, bb, nCol);
1815+
1816+
for(int i = 0; i < nCol; i++) {
1817+
assertEquals(r[i], aa.getValue(nRow, i, nCol), 0.0);
1818+
assertEquals(r[i], bb.getValue(nRow, i, nCol), 0.0);
1819+
}
1820+
}
1821+
1822+
@Test
1823+
public void colSumSq() {
1824+
double[] aa = new double[nCol + 2];
1825+
double[] bb = new double[nCol + 2];
1826+
int[] counts = getCounts(nRow, 321652);
1827+
a.colSumSq(aa, counts, ColIndexFactory.create(nCol));
1828+
b.colSumSq(bb, counts, ColIndexFactory.create(nCol));
1829+
assertArrayEquals(aa, bb, 0.0000001);
1830+
}
1831+
1832+
@Test
1833+
public void multiplyScalar() {
1834+
double[] aa = new double[(nCol + 1) * 4];
1835+
double[] bb = new double[(nCol + 1) * 4];
1836+
Random r = new Random(3222);
1837+
for(int i = 0; i < 10; i++) {
1838+
int di = r.nextInt(nRow);
1839+
int ur = r.nextInt(4);
1840+
a.multiplyScalar(32, aa, ur, di, ColIndexFactory.create(nCol).shift(1));
1841+
b.multiplyScalar(32, bb, ur, di, ColIndexFactory.create(nCol).shift(1));
1842+
}
1843+
assertArrayEquals(aa, bb, 0.0000001);
1844+
1845+
}
1846+
1847+
@Test
1848+
public void subtractTuple(){
1849+
double[] r = TestUtils.generateTestVector(nCol, 1, 10, 0.9, 222);
1850+
IDictionary aa = a.subtractTuple(r);
1851+
IDictionary bb = b.subtractTuple(r);
1852+
1853+
compare(aa, bb, nCol);
1854+
}
1855+
1856+
@Test
1857+
public void cbind(){
1858+
IDictionary aa = a.cbind(b, nCol);
1859+
IDictionary bb = b.cbind(a, nCol);
1860+
compare(aa, bb, nCol * 2);
1861+
}
17741862
}

0 commit comments

Comments
 (0)