Skip to content

Commit 174ed30

Browse files
committed
corrections
1 parent 81c5a51 commit 174ed30

File tree

10 files changed

+325
-144
lines changed

10 files changed

+325
-144
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,48 @@ public class IdentityDictionary extends AIdentityDictionary {
5151
*
5252
* @param nRowCol The number of rows and columns in this identity matrix.
5353
*/
54-
public IdentityDictionary(int nRowCol) {
54+
private IdentityDictionary(int nRowCol) {
5555
super(nRowCol);
5656
}
5757

58+
/**
59+
* Create an identity matrix dictionary. It behaves as if allocated a Sparse Matrix block but exploits that the
60+
* structure is known to have certain properties.
61+
*
62+
* @param nRowCol The number of rows and columns in this identity matrix.
63+
*/
64+
public static IDictionary create(int nRowCol) {
65+
return create(nRowCol, false);
66+
}
67+
5868
/**
5969
* Create an identity matrix dictionary. It behaves as if allocated a Sparse Matrix block but exploits that the
6070
* structure is known to have certain properties.
6171
*
6272
* @param nRowCol The number of rows and columns in this identity matrix.
6373
* @param withEmpty Whether the matrix should contain an empty row at the end.
6474
*/
65-
public IdentityDictionary(int nRowCol, boolean withEmpty) {
75+
private IdentityDictionary(int nRowCol, boolean withEmpty) {
6676
super(nRowCol, withEmpty);
6777
}
6878

79+
/**
80+
* Create an identity matrix dictionary. It behaves as if allocated a Sparse Matrix block but exploits that the
81+
* structure is known to have certain properties.
82+
*
83+
* @param nRowCol The number of rows and columns in this identity matrix.
84+
* @param withEmpty Whether the matrix should contain an empty row at the end.
85+
*/
86+
public static IDictionary create(int nRowCol, boolean withEmpty) {
87+
if(nRowCol == 1) {
88+
if(withEmpty)
89+
return new Dictionary(new double[] {1, 0});
90+
else
91+
return new Dictionary(new double[] {1});
92+
}
93+
return new IdentityDictionary(nRowCol, withEmpty);
94+
}
95+
6996
@Override
7097
public double[] getValues() {
7198
if(nRowCol < 3) {
@@ -129,7 +156,6 @@ public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) {
129156
}
130157
}
131158

132-
133159
@Override
134160
public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) {
135161
boolean same = false;
@@ -233,7 +259,6 @@ public void colProduct(double[] res, int[] counts, IColIndex colIndexes) {
233259
}
234260
}
235261

236-
237262
@Override
238263
public double sum(int[] counts, int ncol) {
239264
// number of rows, change this.
@@ -255,7 +280,7 @@ public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNum
255280
if(idxStart == 0 && idxEnd == nRowCol)
256281
return new IdentityDictionary(nRowCol, withEmpty);
257282
else
258-
return new IdentityDictionarySlice(nRowCol, withEmpty, idxStart, idxEnd);
283+
return IdentityDictionarySlice.create(nRowCol, withEmpty, idxStart, idxEnd);
259284
}
260285

261286
@Override
@@ -265,6 +290,8 @@ public long getNumberNonZeros(int[] counts, int nCol) {
265290

266291
@Override
267292
public int[] countNNZZeroColumns(int[] counts) {
293+
if(withEmpty)
294+
return Arrays.copyOf(counts, nRowCol); // one less.
268295
return counts; // interesting ... but true.
269296
}
270297

@@ -322,28 +349,25 @@ private void addToEntryVectorizedNorm(double[] v, int f1, int f2, int f3, int f4
322349
v[t8 * nCol + f8] += 1;
323350
}
324351

325-
@Override
326-
public MatrixBlockDictionary getMBDict(){
352+
@Override
353+
public MatrixBlockDictionary getMBDict() {
327354
return getMBDict(nRowCol);
328355
}
329356

330357
@Override
331358
public MatrixBlockDictionary createMBDict(int nCol) {
332-
333359
if(withEmpty) {
334360
final SparseBlock sb = SparseBlockFactory.createIdentityMatrixWithEmptyRow(nRowCol);
335361
final MatrixBlock identity = new MatrixBlock(nRowCol + 1, nRowCol, nRowCol, sb);
336362
return new MatrixBlockDictionary(identity);
337363
}
338364
else {
339-
340365
final SparseBlock sb = SparseBlockFactory.createIdentityMatrix(nRowCol);
341366
final MatrixBlock identity = new MatrixBlock(nRowCol, nRowCol, nRowCol, sb);
342367
return new MatrixBlockDictionary(identity);
343368
}
344369
}
345370

346-
347371
@Override
348372
public void write(DataOutput out) throws IOException {
349373
out.writeByte(DictionaryFactory.Type.IDENTITY.ordinal());
@@ -403,7 +427,7 @@ public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColInd
403427
@Override
404428
public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) {
405429
// similar to fused transpose left into right locations.
406-
430+
407431
final int leftSide = rowsLeft.size();
408432
final int colsOut = result.getNumColumns();
409433
final int commonDim = Math.min(left.length / leftSide, nRowCol);
@@ -431,7 +455,6 @@ public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex cols
431455
}
432456
}
433457

434-
435458
@Override
436459
public boolean equals(IDictionary o) {
437460
if(o instanceof IdentityDictionary && //

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,34 @@ public class IdentityDictionarySlice extends AIdentityDictionary {
4949
*/
5050
public IdentityDictionarySlice(int nRowCol, boolean withEmpty, int l, int u) {
5151
super(nRowCol, withEmpty);
52-
if(u > nRowCol || l < 0 || l >= u)
53-
throw new DMLRuntimeException("Invalid slice Identity: " + nRowCol + " range: " + l + "--" + u);
5452
this.l = l;
5553
this.u = u;
5654
}
5755

56+
/**
57+
* Create an identity matrix dictionary slice (if other groups are not more applicable). It behaves as if allocated a
58+
* Sparse Matrix block but exploits that the structure is known to have certain properties.
59+
*
60+
* @param nRowCol the number of rows and columns in this identity matrix.
61+
* @param withEmpty Whether the matrix should contain an empty row at the end.
62+
* @param l the lower index to start at (inclusive)
63+
* @param u the upper index to end at (not inclusive)
64+
*/
65+
public static IDictionary create(int nRowCol, boolean withEmpty, int l, int u) {
66+
if(u > nRowCol || l < 0 || l >= u)
67+
throw new DMLRuntimeException("Invalid slice Identity: " + nRowCol + " range: " + l + "--" + u);
68+
if(nRowCol == 1) {
69+
if(withEmpty)
70+
return new Dictionary(new double[] {1, 0});
71+
else
72+
return new Dictionary(new double[] {1});
73+
}
74+
else if(l == 0 && u == nRowCol)
75+
return IdentityDictionary.create(nRowCol, withEmpty);
76+
else
77+
return new IdentityDictionarySlice(nRowCol, withEmpty, l, u);
78+
}
79+
5880
@Override
5981
public double[] getValues() {
6082
LOG.warn("Should not call getValues on Identity Dictionary");
@@ -96,9 +118,15 @@ public static long getInMemorySize(int numberColumns) {
96118

97119
@Override
98120
public double[] aggregateRows(Builtin fn, int nCol) {
99-
double[] ret = new double[nRowCol];
100-
Arrays.fill(ret, l, u, fn.execute(1, 0));
101-
return ret;
121+
double[] ret = new double[nRowCol + (withEmpty ? 1 : 0)];
122+
if(l + 1 == u) {
123+
ret[l] = 1;
124+
return ret;
125+
}
126+
else {
127+
Arrays.fill(ret, l, u, fn.execute(1, 0));
128+
return ret;
129+
}
102130
}
103131

104132
@Override
@@ -139,19 +167,16 @@ public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) {
139167

140168
@Override
141169
public double[] sumAllRowsToDoubleWithReference(double[] reference) {
142-
double[] ret = new double[getNumberOfValues(reference.length)];
170+
final double[] ret = new double[getNumberOfValues(reference.length)];
143171
double refSum = 0;
144172
for(int i = 0; i < reference.length; i++)
145173
refSum += reference[i];
146-
for(int i = 0; i < ret.length; i++) {
147-
if(i < l || i > u)
148-
ret[i] = refSum;
149-
else
150-
ret[i] = 1 + refSum;
151-
}
152-
153-
if(withEmpty)
154-
ret[ret.length - 1] += -1;
174+
for(int i = 0; i < l; i++)
175+
ret[i] = refSum;
176+
for(int i = l; i < u; i++)
177+
ret[i] = 1 + refSum;
178+
for(int i = u; i < ret.length; i++)
179+
ret[i] = refSum;
155180
return ret;
156181
}
157182

@@ -180,9 +205,8 @@ public void colSum(double[] c, int[] counts, IColIndex colIndexes) {
180205

181206
@Override
182207
public double sum(int[] counts, int ncol) {
183-
int end = withEmpty && u == ncol ? u - 1 : u;
184208
double s = 0.0;
185-
for(int i = l; i < end; i++)
209+
for(int i = l; i < u; i++)
186210
s += counts[i];
187211
return s;
188212
}
@@ -241,22 +265,22 @@ public void addToEntry(final double[] v, final int fr, final int to, final int n
241265
public boolean equals(IDictionary o) {
242266
if(o instanceof IdentityDictionarySlice) {
243267
IdentityDictionarySlice os = ((IdentityDictionarySlice) o);
244-
return os.nRowCol == nRowCol && os.l == l && os.u == u;
268+
return os.nRowCol == nRowCol && os.l == l && os.u == u && withEmpty == os.withEmpty;
245269
}
246270
else if(o instanceof IdentityDictionary)
247271
return false;
248272
else
249273
return getMBDict().equals(o);
250274
}
251275

252-
@Override
253-
public MatrixBlockDictionary getMBDict(){
276+
@Override
277+
public MatrixBlockDictionary getMBDict() {
254278
return getMBDict(nRowCol);
255279
}
256280

257281
@Override
258282
public MatrixBlockDictionary createMBDict(int nCol) {
259-
MatrixBlock identity = new MatrixBlock(nRowCol + (withEmpty ? 1 : 0), u - l, true);
283+
MatrixBlock identity = new MatrixBlock(nRowCol + (withEmpty ? 1 : 0), u - l, true);
260284
for(int i = l; i < u; i++)
261285
identity.set(i, i - l, 1.0);
262286
return new MatrixBlockDictionary(identity);

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,7 @@ public int[] countNNZZeroColumns(int[] counts) {
15801580
final int aix[] = sb.indexes(i);
15811581
for(int j = apos; j < alen; j++) {
15821582

1583-
ret[aix[i]] += counts[i];
1583+
ret[aix[j]] += counts[i];
15841584
}
15851585
}
15861586
}

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import java.io.DataOutput;
2424
import java.io.IOException;
2525

26-
import org.apache.commons.lang3.NotImplementedException;
2726
import org.apache.sysds.runtime.functionobjects.Builtin;
2827
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2928
import org.apache.sysds.utils.MemoryEstimates;
@@ -47,11 +46,24 @@ protected QDictionary(byte[] values, double scale, int nCol) {
4746
_nCol = nCol;
4847
}
4948

49+
public static QDictionary create(byte[] values, double scale, int nCol, boolean check) {
50+
if(scale == 0)
51+
return null;
52+
if(check) {
53+
boolean containsOnlyZero = true;
54+
for(int i = 0; i < values.length && containsOnlyZero; i++) {
55+
if(values[i] != 0)
56+
containsOnlyZero = false;
57+
}
58+
if(containsOnlyZero)
59+
return null;
60+
}
61+
return new QDictionary(values, scale, nCol);
62+
}
63+
5064
@Override
5165
public double[] getValues() {
52-
if(_values == null) {
53-
return new double[0];
54-
}
66+
5567
double[] res = new double[_values.length];
5668
for(int i = 0; i < _values.length; i++) {
5769
res[i] = getValue(i);
@@ -69,18 +81,6 @@ public final double getValue(int r, int c, int nCol) {
6981
return _values[r * nCol + c] * _scale;
7082
}
7183

72-
public byte getValueByte(int i) {
73-
return _values[i];
74-
}
75-
76-
public byte[] getValuesByte() {
77-
return _values;
78-
}
79-
80-
public double getScale() {
81-
return _scale;
82-
}
83-
8484
@Override
8585
public long getInMemorySize() {
8686
// object + values array + double
@@ -102,26 +102,6 @@ public double aggregate(double init, Builtin fn) {
102102
return ret;
103103
}
104104

105-
@Override
106-
public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) {
107-
throw new NotImplementedException();
108-
}
109-
110-
@Override
111-
public double[] aggregateRows(Builtin fn, final int nCol) {
112-
if(nCol == 1)
113-
return getValues();
114-
final int nRows = _values.length / nCol;
115-
double[] res = new double[nRows];
116-
for(int i = 0; i < nRows; i++) {
117-
final int off = i * nCol;
118-
res[i] = _values[off];
119-
for(int j = off + 1; j < off + nCol; j++)
120-
res[i] = fn.execute(res[i], _values[j] * _scale);
121-
}
122-
return res;
123-
}
124-
125105
private int size() {
126106
return _values.length;
127107
}
@@ -159,7 +139,7 @@ public long getExactSizeOnDisk() {
159139

160140
@Override
161141
public int getNumberOfValues(int nCol) {
162-
return (_values == null) ? 0 : _values.length / nCol;
142+
return _values.length / nCol;
163143
}
164144

165145
@Override
@@ -185,10 +165,7 @@ public double[] sumAllRowsToDoubleSq(int nrColumns) {
185165
}
186166

187167
private double sumRow(int k, int nrColumns) {
188-
if(_values == null)
189-
return 0;
190168
int valOff = k * nrColumns;
191-
192169
int res = 0;
193170
for(int i = 0; i < nrColumns; i++) {
194171
res += _values[valOff + i];
@@ -197,8 +174,6 @@ private double sumRow(int k, int nrColumns) {
197174
}
198175

199176
private double sumRowSq(int k, int nrColumns) {
200-
if(_values == null)
201-
return 0;
202177
int valOff = k * nrColumns;
203178
double res = 0.0;
204179
for(int i = 0; i < nrColumns; i++)
@@ -215,11 +190,6 @@ public String getString(int colIndexes) {
215190
return sb.toString();
216191
}
217192

218-
public Dictionary makeDoubleDictionary() {
219-
double[] doubleValues = getValues();
220-
return Dictionary.create(doubleValues);
221-
}
222-
223193
public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) {
224194
int numberTuples = getNumberOfValues(previousNumberOfColumns);
225195
int tupleLengthAfter = idxEnd - idxStart;
@@ -274,7 +244,11 @@ public DictType getDictType() {
274244

275245
@Override
276246
public double getSparsity() {
277-
throw new NotImplementedException();
247+
int nnz = 0;
248+
for(int i = 0; i < _values.length; i++) {
249+
nnz += _values[i] == 0 ? 0 : 1;
250+
}
251+
return (double) nnz / _values.length;
278252
}
279253

280254
@Override

0 commit comments

Comments
 (0)