Skip to content

Commit 5124ca5

Browse files
committed
more tests
1 parent fc442b9 commit 5124ca5

File tree

4 files changed

+309
-114
lines changed

4 files changed

+309
-114
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,12 @@ public long getNumberNonZerosWithReference(int[] counts, double[] reference, int
279279

280280
@Override
281281
public boolean containsValueWithReference(double pattern, double[] reference) {
282+
if(Double.isNaN(pattern)){
283+
for(int i = 0 ; i < reference.length; i++)
284+
if(Double.isNaN(reference[i]))
285+
return true;
286+
return containsValue(pattern);
287+
}
282288
return getMBDict().containsValueWithReference(pattern, reference);
283289
}
284290

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,8 @@ public boolean containsValue(double pattern) {
726726

727727
@Override
728728
public boolean containsValueWithReference(double pattern, double[] reference) {
729+
if(Double.isNaN(pattern))
730+
return super.containsValueWithReference(pattern, reference);
729731
final int nCol = reference.length;
730732
for(int i = 0; i < _values.length; i++)
731733
if(_values[i] + reference[i % nCol] == pattern)
@@ -913,46 +915,7 @@ public IDictionary replaceWithReference(double pattern, double replace, double[]
913915
final int nCol = reference.length;
914916
final int nRow = _values.length / nCol;
915917
if(Util.eq(pattern, Double.NaN)) {
916-
Set<Integer> colsWithNan = null;
917-
for(int i = 0; i < reference.length; i++) {
918-
if(Util.eq(reference[i], Double.NaN)) {
919-
if(colsWithNan == null)
920-
colsWithNan = new HashSet<>();
921-
colsWithNan.add(i);
922-
reference[i] = replace;
923-
}
924-
}
925-
926-
if(colsWithNan != null) {
927-
final double[] retV = new double[_values.length];
928-
for(int i = 0; i < nRow; i++) {
929-
final int off = i * reference.length;
930-
for(int j = 0; j < nCol; j++) {
931-
final int cell = off + j;
932-
if(colsWithNan.contains(j))
933-
retV[cell] = 0;
934-
else if(Util.eq(_values[cell], Double.NaN))
935-
retV[cell] = replace - reference[j];
936-
else
937-
retV[cell] = _values[cell];
938-
}
939-
}
940-
return create(retV);
941-
}
942-
else {
943-
final double[] retV = new double[_values.length];
944-
for(int i = 0; i < nRow; i++) {
945-
final int off = i * reference.length;
946-
for(int j = 0; j < nCol; j++) {
947-
final int cell = off + j;
948-
if(Util.eq(_values[cell], Double.NaN))
949-
retV[cell] = replace - reference[j];
950-
else
951-
retV[cell] = _values[cell] ;
952-
}
953-
}
954-
return create(retV);
955-
}
918+
return replaceWithReferenceNaN(replace, reference, nCol, nRow);
956919
}
957920
else {
958921
final double[] retV = new double[_values.length];
@@ -969,6 +932,62 @@ else if(Util.eq(_values[cell], Double.NaN))
969932
}
970933
}
971934

935+
private IDictionary replaceWithReferenceNaN(double replace, double[] reference, final int nCol, final int nRow) {
936+
final Set<Integer> colsWithNan = getColsWithNan(replace, reference);
937+
final double[] retV;
938+
if(colsWithNan != null) {
939+
if(colsWithNan.size() == nCol && replace == 0)
940+
return null;
941+
retV = new double[_values.length];
942+
replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, _values, retV);
943+
}
944+
else {
945+
retV = new double[_values.length];
946+
replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, _values);
947+
}
948+
return create(retV);
949+
}
950+
951+
protected static Set<Integer> getColsWithNan(double replace, double[] reference) {
952+
Set<Integer> colsWithNan = null;
953+
for(int i = 0; i < reference.length; i++) {
954+
if(Util.eq(reference[i], Double.NaN)) {
955+
if(colsWithNan == null)
956+
colsWithNan = new HashSet<>();
957+
colsWithNan.add(i);
958+
reference[i] = replace;
959+
}
960+
}
961+
return colsWithNan;
962+
}
963+
964+
protected static void replaceWithReferenceNanDenseWithoutNanCols(final double replace, final double[] reference,
965+
final int nRow, final int nCol, final double[] retV, final double[] values) {
966+
int off = 0;
967+
for(int i = 0; i < nRow; i++) {
968+
for(int j = 0; j < nCol; j++) {
969+
final double v = values[off];
970+
retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
971+
}
972+
}
973+
}
974+
975+
protected static void replaceWithReferenceNanDenseWithNanCols(final double replace, final double[] reference,
976+
final int nRow, final int nCol, Set<Integer> colsWithNan, final double[] values, final double[] retV) {
977+
int off = 0;
978+
for(int i = 0; i < nRow; i++) {
979+
for(int j = 0; j < nCol; j++) {
980+
final double v = values[off];
981+
if(colsWithNan.contains(j))
982+
retV[off++] = 0;
983+
else if(Util.eq(v, Double.NaN))
984+
retV[off++] = replace - reference[j];
985+
else
986+
retV[off++] = v;
987+
}
988+
}
989+
}
990+
972991
@Override
973992
public void product(double[] ret, int[] counts, int nCol) {
974993
if(ret[0] == 0)
@@ -1024,17 +1043,22 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
10241043
if(ret[0] == 0)
10251044
return;
10261045
final MathContext cont = MathContext.DECIMAL128;
1027-
final int len = counts.length;
1046+
final int nRow = counts.length;
10281047
final int nCol = reference.length;
1048+
10291049
BigDecimal tmp = BigDecimal.ONE;
10301050
int off = 0;
1031-
for(int i = 0; i < len; i++) {
1051+
for(int i = 0; i < nRow; i++) {
10321052
for(int j = 0; j < nCol; j++) {
10331053
final double v = _values[off++] + reference[j];
10341054
if(v == 0) {
10351055
ret[0] = 0;
10361056
return;
10371057
}
1058+
else if(!Double.isFinite(v)) {
1059+
ret[0] = v;
1060+
return;
1061+
}
10381062
tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont);
10391063
}
10401064
}
@@ -1044,6 +1068,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
10441068
ret[0] = 0;
10451069
else if(!Double.isInfinite(ret[0]))
10461070
ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue();
1071+
10471072
}
10481073

10491074
@Override
@@ -1192,7 +1217,7 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef
11921217
public boolean equals(IDictionary o) {
11931218
if(o instanceof Dictionary)
11941219
return Arrays.equals(_values, ((Dictionary) o)._values);
1195-
else if (o != null)
1220+
else if(o != null)
11961221
return o.equals(this);
11971222
return false;
11981223
}
@@ -1219,7 +1244,7 @@ public IDictionary reorder(int[] reorder) {
12191244
return ret;
12201245
}
12211246

1222-
@Override
1247+
@Override
12231248
protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols,
12241249
IColIndex aggregateColumns) {
12251250

@@ -1264,7 +1289,7 @@ private void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex agg
12641289
retIdx = 0;
12651290
}
12661291

1267-
@Override
1292+
@Override
12681293
protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols,
12691294
int nColRight) {
12701295
final int thisColsSize = thisCols.size();
@@ -1291,15 +1316,14 @@ protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b
12911316
return Dictionary.create(ret);
12921317
}
12931318

1294-
private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) {
1319+
private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) {
12951320
if(v != 0) {
12961321
for(int k = sPos; k < sEnd; k++) { // cols right with value
12971322
ret[offOut + sIdx[k]] += v * sVals[k];
12981323
}
12991324
}
13001325
}
13011326

1302-
13031327
@Override
13041328
public IDictionary append(double[] row) {
13051329
double[] retV = new double[_values.length + row.length];
@@ -1308,5 +1332,4 @@ public IDictionary append(double[] row) {
13081332
return new Dictionary(retV);
13091333
}
13101334

1311-
13121335
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java

Lines changed: 29 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import java.math.BigDecimal;
2626
import java.math.MathContext;
2727
import java.util.Arrays;
28-
import java.util.HashSet;
2928
import java.util.Set;
3029

3130
import org.apache.commons.lang3.NotImplementedException;
@@ -36,6 +35,7 @@
3635
import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex;
3736
import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex;
3837
import org.apache.sysds.runtime.compress.utils.Util;
38+
import org.apache.sysds.runtime.data.DenseBlock;
3939
import org.apache.sysds.runtime.data.DenseBlockFP64;
4040
import org.apache.sysds.runtime.data.SparseBlock;
4141
import org.apache.sysds.runtime.data.SparseBlockCSR;
@@ -1495,6 +1495,8 @@ public boolean containsValue(double pattern) {
14951495

14961496
@Override
14971497
public boolean containsValueWithReference(double pattern, double[] reference) {
1498+
if(Double.isNaN(pattern))
1499+
return super.containsValueWithReference(pattern, reference);
14981500
if(_data.isInSparseFormat()) {
14991501
final SparseBlock sb = _data.getSparseBlock();
15001502
for(int i = 0; i < _data.getNumRows(); i++) {
@@ -2059,9 +2061,8 @@ public IDictionary replace(double pattern, double replace, int nCol) {
20592061

20602062
@Override
20612063
public IDictionary replaceWithReference(double pattern, double replace, double[] reference) {
2062-
if(Util.eq(pattern, Double.NaN)) {
2064+
if(Util.eq(pattern, Double.NaN))
20632065
return replaceWithReferenceNan(replace, reference);
2064-
}
20652066

20662067
final int nRow = _data.getNumRows();
20672068
final int nCol = _data.getNumColumns();
@@ -2108,27 +2109,19 @@ public IDictionary replaceWithReference(double pattern, double replace, double[]
21082109
}
21092110

21102111
private IDictionary replaceWithReferenceNan(double replace, double[] reference) {
2111-
2112+
final Set<Integer> colsWithNan = Dictionary.getColsWithNan(replace, reference);
21122113
final int nRow = _data.getNumRows();
21132114
final int nCol = _data.getNumColumns();
2115+
if(colsWithNan != null && colsWithNan.size() == nCol && replace == 0)
2116+
return null;
2117+
21142118
final MatrixBlock ret = new MatrixBlock(nRow, nCol, false);
21152119
ret.allocateDenseBlock();
2116-
2117-
Set<Integer> colsWithNan = null;
2118-
for(int i = 0; i < reference.length; i++) {
2119-
if(Util.eq(reference[i], Double.NaN)) {
2120-
if(colsWithNan == null)
2121-
colsWithNan = new HashSet<>();
2122-
colsWithNan.add(i);
2123-
reference[i] = replace;
2124-
}
2125-
}
2120+
final double[] retV = ret.getDenseBlockValues();
21262121

21272122
if(colsWithNan == null) {
2128-
2129-
final double[] retV = ret.getDenseBlockValues();
2130-
int off = 0;
21312123
if(_data.isInSparseFormat()) {
2124+
final DenseBlock db = ret.getDenseBlock();
21322125
final SparseBlock sb = _data.getSparseBlock();
21332126
for(int i = 0; i < nRow; i++) {
21342127
if(sb.isEmpty(i))
@@ -2137,30 +2130,22 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference)
21372130
final int apos = sb.pos(i);
21382131
final int alen = sb.size(i) + apos;
21392132
final double[] avals = sb.values(i);
2133+
final int[] aix = sb.indexes(i);
21402134
int j = 0;
2135+
int off = db.pos(i);
21412136
for(int k = apos; k < alen; k++) {
21422137
final double v = avals[k];
2143-
retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
2138+
retV[off + aix[k]] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
21442139
}
21452140
}
21462141
}
21472142
else {
21482143
final double[] values = _data.getDenseBlockValues();
2149-
for(int i = 0; i < nRow; i++) {
2150-
for(int j = 0; j < nCol; j++) {
2151-
final double v = values[off];
2152-
retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
2153-
}
2154-
}
2144+
Dictionary.replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, values);
21552145
}
21562146

2157-
ret.recomputeNonZeros();
2158-
ret.examSparsity();
2159-
return MatrixBlockDictionary.create(ret);
21602147
}
21612148
else {
2162-
2163-
final double[] retV = ret.getDenseBlockValues();
21642149
if(_data.isInSparseFormat()) {
21652150
final SparseBlock sb = _data.getSparseBlock();
21662151
for(int i = 0; i < nRow; i++) {
@@ -2170,10 +2155,10 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference)
21702155
final int apos = sb.pos(i);
21712156
final int alen = sb.size(i) + apos;
21722157
final double[] avals = sb.values(i);
2173-
final int[] aidx = sb.indexes(i);
2158+
final int[] aix = sb.indexes(i);
21742159
for(int k = apos; k < alen; k++) {
2175-
final int c = aidx[k];
2176-
final int outIdx = off + aidx[k];
2160+
final int c = aix[k];
2161+
final int outIdx = off + aix[k];
21772162
final double v = avals[k];
21782163
if(colsWithNan.contains(c))
21792164
retV[outIdx] = 0;
@@ -2185,27 +2170,16 @@ else if(Util.eq(v, Double.NaN))
21852170
}
21862171
}
21872172
else {
2188-
int off = 0;
21892173
final double[] values = _data.getDenseBlockValues();
2190-
for(int i = 0; i < nRow; i++) {
2191-
for(int j = 0; j < nCol; j++) {
2192-
final double v = values[off];
21932174

2194-
if(colsWithNan.contains(j))
2195-
retV[off++] = 0;
2196-
else if(Util.eq(v, Double.NaN))
2197-
retV[off++] = replace - reference[j];
2198-
else
2199-
retV[off++] = v;
2200-
}
2201-
}
2175+
Dictionary.replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, values,
2176+
retV);
22022177
}
2203-
2204-
ret.recomputeNonZeros();
2205-
ret.examSparsity();
2206-
return MatrixBlockDictionary.create(ret);
22072178
}
22082179

2180+
ret.recomputeNonZeros();
2181+
ret.examSparsity();
2182+
return MatrixBlockDictionary.create(ret);
22092183
}
22102184

22112185
@Override
@@ -2277,6 +2251,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
22772251
}
22782252
else
22792253
values = _data.getDenseBlockValues();
2254+
22802255
BigDecimal tmp = BigDecimal.ONE;
22812256
int off = 0;
22822257
for(int i = 0; i < nRow; i++) {
@@ -2286,6 +2261,10 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
22862261
ret[0] = 0;
22872262
return;
22882263
}
2264+
else if(!Double.isFinite(v)) {
2265+
ret[0] = v;
2266+
return;
2267+
}
22892268
tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont);
22902269
}
22912270
}
@@ -2294,7 +2273,8 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
22942273
if(Math.abs(tmp.doubleValue()) == 0)
22952274
ret[0] = 0;
22962275
else if(!Double.isInfinite(ret[0]))
2297-
ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue();
2276+
ret[0] = new BigDecimal(ret[0]).multiply(tmp, cont).doubleValue();
2277+
22982278
}
22992279

23002280
@Override

0 commit comments

Comments
 (0)