Skip to content

Commit 96fd5da

Browse files
committed
[SYSTEMDS-3808] Dictionary Compressed Combine
This commit speedup the combining of dictionaries via custum hashmaps. Closes #2166 Signed-off-by: Sebastian Baunsgaard <[email protected]>
1 parent 809490f commit 96fd5da

File tree

12 files changed

+888
-351
lines changed

12 files changed

+888
-351
lines changed

src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java

Lines changed: 389 additions & 112 deletions
Large diffs are not rendered by default.

src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919

2020
package org.apache.sysds.runtime.compress.estim.encoding;
2121

22-
import java.util.Map;
23-
2422
import org.apache.commons.lang3.tuple.ImmutablePair;
2523
import org.apache.commons.lang3.tuple.Pair;
2624
import org.apache.sysds.runtime.compress.CompressionSettings;
2725
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
26+
import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
2827

2928
/** Const encoding for cases where the entire group of columns is the same value */
3029
public class ConstEncoding extends AEncode {
@@ -41,7 +40,7 @@ public IEncode combine(IEncode e) {
4140
}
4241

4342
@Override
44-
public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
43+
public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
4544
return new ImmutablePair<>(e, null);
4645
}
4746

src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java

Lines changed: 188 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,40 @@
1919

2020
package org.apache.sysds.runtime.compress.estim.encoding;
2121

22-
import java.util.HashMap;
23-
import java.util.Map;
24-
2522
import org.apache.commons.lang3.tuple.ImmutablePair;
2623
import org.apache.commons.lang3.tuple.Pair;
2724
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
2825
import org.apache.sysds.runtime.compress.CompressionSettings;
2926
import org.apache.sysds.runtime.compress.DMLCompressionException;
3027
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
28+
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar;
29+
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToCharPByte;
3130
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
3231
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
3332
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
33+
import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
3434

3535
/**
3636
* An Encoding that contains a value on each row of the input.
3737
*/
3838
public class DenseEncoding extends AEncode {
3939

40+
private static boolean zeroWarn = false;
41+
4042
private final AMapToData map;
4143

4244
public DenseEncoding(AMapToData map) {
4345
this.map = map;
4446

4547
if(CompressedMatrixBlock.debug) {
48+
// if(!zeroWarn) {
4649
int[] freq = map.getCounts();
47-
for(int i = 0; i < freq.length; i++) {
48-
if(freq[i] == 0)
49-
throw new DMLCompressionException("Invalid counts in fact contains 0");
50+
for(int i = 0; i < freq.length && !zeroWarn; i++) {
51+
if(freq[i] == 0) {
52+
LOG.warn("Dense encoding contains zero encoding, indicating not all dictionary entries are in use");
53+
zeroWarn = true;
54+
55+
}
5056
}
5157
}
5258
}
@@ -62,7 +68,7 @@ else if(e instanceof SparseEncoding)
6268
}
6369

6470
@Override
65-
public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
71+
public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
6672
if(e instanceof EmptyEncoding || e instanceof ConstEncoding)
6773
return new ImmutablePair<>(this, null);
6874
else if(e instanceof SparseEncoding)
@@ -106,14 +112,14 @@ private AMapToData assignSparse(SparseEncoding e) {
106112
return ret;
107113
}
108114

109-
private final Pair<IEncode, Map<Integer, Integer>> combineSparseHashMap(final AMapToData ret) {
115+
private final Pair<IEncode, HashMapLongInt> combineSparseHashMap(final AMapToData ret) {
110116
final int size = ret.size();
111-
final Map<Integer, Integer> m = new HashMap<>(size);
117+
final HashMapLongInt m = new HashMapLongInt(100);
112118
for(int r = 0; r < size; r++) {
113119
final int prev = ret.getIndex(r);
114120
final int v = m.size();
115-
final Integer mv = m.putIfAbsent(prev, v);
116-
if(mv == null)
121+
final int mv = m.putIfAbsent(prev, v);
122+
if(mv == -1)
117123
ret.set(r, v);
118124
else
119125
ret.set(r, mv);
@@ -146,28 +152,44 @@ protected DenseEncoding combineDense(final DenseEncoding other) {
146152
final int nVL = lm.getUnique();
147153
final int nVR = rm.getUnique();
148154
final int size = map.size();
149-
final int maxUnique = nVL * nVR;
150-
155+
int maxUnique = nVL * nVR;
156+
final DenseEncoding retE;
151157
final AMapToData ret = MapToFactory.create(size, maxUnique);
152-
153-
if(maxUnique > size && maxUnique > 2048) {
158+
if(maxUnique < Math.max(nVL, nVR)) {// overflow
159+
final HashMapLongInt m = new HashMapLongInt(Math.max(100, size / 100));
160+
retE = combineDenseWithHashMapLong(lm, rm, size, nVL, ret, m);
161+
}
162+
else if(maxUnique > size && maxUnique > 2048) {
154163
// aka there is more maxUnique than rows.
155-
final Map<Integer, Integer> m = new HashMap<>(size);
156-
return combineDenseWithHashMap(lm, rm, size, nVL, ret, m);
164+
final HashMapLongInt m = new HashMapLongInt(Math.max(100, maxUnique / 100));
165+
retE = combineDenseWithHashMap(lm, rm, size, nVL, ret, m);
157166
}
158167
else {
159168
final AMapToData m = MapToFactory.create(maxUnique, maxUnique + 1);
160-
return combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique, m);
169+
retE = combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique, m);
170+
}
171+
172+
if(retE.getUnique() < 0) {
173+
String th = this.toString();
174+
String ot = other.toString();
175+
String cm = retE.toString();
176+
177+
if(th.length() > 1000)
178+
th = th.substring(0, 1000);
179+
if(ot.length() > 1000)
180+
ot = ot.substring(0, 1000);
181+
if(cm.length() > 1000)
182+
cm = cm.substring(0, 1000);
183+
throw new DMLCompressionException(
184+
"Failed to combine dense encodings correctly: Number unique values is lower than max input: \n\n" + th
185+
+ "\n\n" + ot + "\n\n" + cm);
161186
}
187+
return retE;
162188
}
163189

164-
private Pair<IEncode, Map<Integer, Integer>> combineDenseNoResize(final DenseEncoding other) {
165-
if(map == other.map) {
166-
LOG.warn("Constructing perfect mapping, this could be optimized to skip hashmap");
167-
final Map<Integer, Integer> m = new HashMap<>(map.size());
168-
for(int i = 0; i < map.getUnique(); i++)
169-
m.put(i * i, i);
170-
return new ImmutablePair<>(this, m); // same object
190+
private Pair<IEncode, HashMapLongInt> combineDenseNoResize(final DenseEncoding other) {
191+
if(map.equals(other.map)) {
192+
return combineSameMapping();
171193
}
172194

173195
final AMapToData lm = map;
@@ -176,40 +198,115 @@ private Pair<IEncode, Map<Integer, Integer>> combineDenseNoResize(final DenseEnc
176198
final int nVL = lm.getUnique();
177199
final int nVR = rm.getUnique();
178200
final int size = map.size();
179-
final int maxUnique = nVL * nVR;
201+
final int maxUnique = (int) Math.min((long) nVL * nVR, (long) size);
180202

181203
final AMapToData ret = MapToFactory.create(size, maxUnique);
182204

183-
final Map<Integer, Integer> m = new HashMap<>(Math.min(size, maxUnique));
205+
final HashMapLongInt m = new HashMapLongInt(Math.max(100, maxUnique / 1000));
184206
return new ImmutablePair<>(combineDenseWithHashMap(lm, rm, size, nVL, ret, m), m);
207+
}
185208

186-
// there can be less unique.
187-
188-
// return new DenseEncoding(ret);
209+
private Pair<IEncode, HashMapLongInt> combineSameMapping() {
210+
LOG.warn("Constructing perfect mapping, this could be optimized to skip hashmap");
211+
final HashMapLongInt m = new HashMapLongInt(Math.max(100, map.size() / 100));
212+
for(int i = 0; i < map.getUnique(); i++)
213+
m.putIfAbsent(i * (map.getUnique() + 1), i);
214+
return new ImmutablePair<>(this, m); // same object
189215
}
190216

191-
private Pair<IEncode, Map<Integer, Integer>> combineSparseNoResize(final SparseEncoding other) {
217+
private Pair<IEncode, HashMapLongInt> combineSparseNoResize(final SparseEncoding other) {
192218
final AMapToData a = assignSparse(other);
193219
return combineSparseHashMap(a);
194220
}
195221

222+
protected final DenseEncoding combineDenseWithHashMapLong(final AMapToData lm, final AMapToData rm, final int size,
223+
final long nVL, final AMapToData ret, HashMapLongInt m) {
224+
if(ret instanceof MapToChar)
225+
for(int r = 0; r < size; r++)
226+
addValHashMapChar((long) lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, (MapToChar) ret);
227+
else
228+
for(int r = 0; r < size; r++)
229+
addValHashMap((long) lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret);
230+
return new DenseEncoding(ret.resize(m.size()));
231+
}
232+
196233
protected final DenseEncoding combineDenseWithHashMap(final AMapToData lm, final AMapToData rm, final int size,
197-
final int nVL, final AMapToData ret, Map<Integer, Integer> m) {
234+
final int nVL, final AMapToData ret, HashMapLongInt m) {
235+
// JIT compile instance checks.
236+
if(ret instanceof MapToChar)
237+
combineDenseWIthHashMapCharOut(lm, rm, size, nVL, (MapToChar) ret, m);
238+
else if(ret instanceof MapToCharPByte)
239+
combineDenseWIthHashMapPByteOut(lm, rm, size, nVL, (MapToCharPByte) ret, m);
240+
else
241+
combineDenseWithHashMapGeneric(lm, rm, size, nVL, ret, m);
242+
ret.setUnique(m.size());
243+
return new DenseEncoding(ret);
198244

245+
}
246+
247+
private final void combineDenseWIthHashMapPByteOut(final AMapToData lm, final AMapToData rm, final int size,
248+
final int nVL, final MapToCharPByte ret, HashMapLongInt m) {
249+
for(int r = 0; r < size; r++)
250+
addValHashMapCharByte(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret);
251+
}
252+
253+
private final void combineDenseWIthHashMapCharOut(final AMapToData lm, final AMapToData rm, final int size,
254+
final int nVL, final MapToChar ret, HashMapLongInt m) {
255+
if(lm instanceof MapToChar && rm instanceof MapToChar)
256+
combineDenseWIthHashMapAllChar(lm, rm, size, nVL, ret, m);
257+
else// some other combination
258+
combineDenseWIthHashMapCharOutGeneric(lm, rm, size, nVL, ret, m);
259+
}
260+
261+
private final void combineDenseWIthHashMapCharOutGeneric(final AMapToData lm, final AMapToData rm, final int size,
262+
final int nVL, final MapToChar ret, HashMapLongInt m) {
263+
for(int r = 0; r < size; r++)
264+
addValHashMapChar(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret);
265+
}
266+
267+
private final void combineDenseWIthHashMapAllChar(final AMapToData lm, final AMapToData rm, final int size,
268+
final int nVL, final MapToChar ret, HashMapLongInt m) {
269+
final MapToChar lmC = (MapToChar) lm;
270+
final MapToChar rmC = (MapToChar) rm;
271+
for(int r = 0; r < size; r++)
272+
addValHashMapChar(lmC.getIndex(r) + rmC.getIndex(r) * nVL, r, m, ret);
273+
274+
}
275+
276+
protected final void combineDenseWithHashMapGeneric(final AMapToData lm, final AMapToData rm, final int size,
277+
final int nVL, final AMapToData ret, HashMapLongInt m) {
199278
for(int r = 0; r < size; r++)
200279
addValHashMap(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret);
201-
return new DenseEncoding(ret.resize(m.size()));
202280
}
203281

204282
protected final DenseEncoding combineDenseWithMapToData(final AMapToData lm, final AMapToData rm, final int size,
205283
final int nVL, final AMapToData ret, final int maxUnique, final AMapToData m) {
284+
if(m instanceof MapToChar)
285+
return combineDenseWithMapToDataToChar(lm, rm, size, nVL, ret, maxUnique, (MapToChar) m);
286+
else
287+
return combineDenseWithMapToDataGeneric(lm, rm, size, nVL, ret, maxUnique, m);
288+
289+
}
290+
291+
protected final DenseEncoding combineDenseWithMapToDataToChar(final AMapToData lm, final AMapToData rm,
292+
final int size, final int nVL, final AMapToData ret, final int maxUnique, final MapToChar m) {
293+
int newUID = 1;
294+
for(int r = 0; r < size; r++)
295+
newUID = addValMapToDataChar(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, newUID, ret);
296+
ret.setUnique(newUID - 1);
297+
return new DenseEncoding(ret);
298+
}
299+
300+
protected final DenseEncoding combineDenseWithMapToDataGeneric(final AMapToData lm, final AMapToData rm,
301+
final int size, final int nVL, final AMapToData ret, final int maxUnique, final AMapToData m) {
206302
int newUID = 1;
207303
for(int r = 0; r < size; r++)
208304
newUID = addValMapToData(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, newUID, ret);
209-
return new DenseEncoding(ret.resize(newUID - 1));
305+
ret.setUnique(newUID - 1);
306+
return new DenseEncoding(ret);
210307
}
211308

212-
protected static int addValMapToData(final int nv, final int r, final AMapToData map, int newId,
309+
protected static int addValMapToDataChar(final int nv, final int r, final MapToChar map, int newId,
213310
final AMapToData d) {
214311
int mv = map.getIndex(nv);
215312
if(mv == 0)
@@ -218,11 +315,56 @@ protected static int addValMapToData(final int nv, final int r, final AMapToData
218315
return newId;
219316
}
220317

221-
protected static void addValHashMap(final int nv, final int r, final Map<Integer, Integer> map,
318+
protected static int addValMapToData(final int nv, final int r, final AMapToData map, int newId,
222319
final AMapToData d) {
320+
int mv = map.getIndex(nv);
321+
if(mv == 0)
322+
mv = map.setAndGet(nv, newId++);
323+
d.set(r, mv - 1);
324+
return newId;
325+
}
326+
327+
protected static void addValHashMap(final int nv, final int r, final HashMapLongInt map, final AMapToData d) {
223328
final int v = map.size();
224-
final Integer mv = map.putIfAbsent(nv, v);
225-
if(mv == null)
329+
final int mv = map.putIfAbsent(nv, v);
330+
if(mv == -1)
331+
d.set(r, v);
332+
else
333+
d.set(r, mv);
334+
}
335+
336+
protected static void addValHashMapChar(final int nv, final int r, final HashMapLongInt map, final MapToChar d) {
337+
final int v = map.size();
338+
final int mv = map.putIfAbsent(nv, v);
339+
if(mv == -1)
340+
d.set(r, v);
341+
else
342+
d.set(r, mv);
343+
}
344+
345+
protected static void addValHashMapCharByte(final int nv, final int r, final HashMapLongInt map,
346+
final MapToCharPByte d) {
347+
final int v = map.size();
348+
final int mv = map.putIfAbsent(nv, v);
349+
if(mv == -1)
350+
d.set(r, v);
351+
else
352+
d.set(r, mv);
353+
}
354+
355+
protected static void addValHashMapChar(final long nv, final int r, final HashMapLongInt map, final MapToChar d) {
356+
final int v = map.size();
357+
final int mv = map.putIfAbsent(nv, v);
358+
if(mv == -1)
359+
d.set(r, v);
360+
else
361+
d.set(r, mv);
362+
}
363+
364+
protected static void addValHashMap(final long nv, final int r, final HashMapLongInt map, final AMapToData d) {
365+
final int v = map.size();
366+
final int mv = map.putIfAbsent(nv, v);
367+
if(mv == -1)
226368
d.set(r, v);
227369
else
228370
d.set(r, mv);
@@ -237,13 +379,18 @@ public int getUnique() {
237379
public EstimationFactors extractFacts(int nRows, double tupleSparsity, double matrixSparsity,
238380
CompressionSettings cs) {
239381
int largestOffs = 0;
240-
241382
int[] counts = map.getCounts();
242383
for(int i = 0; i < counts.length; i++)
243384
if(counts[i] > largestOffs)
244385
largestOffs = counts[i];
245-
else if(counts[i] == 0)
246-
throw new DMLCompressionException("Invalid count of 0 all values should have at least one instance");
386+
else if(counts[i] == 0) {
387+
if(!zeroWarn) {
388+
LOG.warn("Invalid count of 0 all values should have at least one instance index: " + i + " of "
389+
+ counts.length);
390+
zeroWarn = true;
391+
}
392+
counts[i] = 1;
393+
}
247394

248395
if(cs.isRLEAllowed())
249396
return new EstimationFactors(map.getUnique(), nRows, largestOffs, counts, 0, nRows, map.countRuns(), false,

src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919

2020
package org.apache.sysds.runtime.compress.estim.encoding;
2121

22-
import java.util.Map;
23-
2422
import org.apache.commons.lang3.tuple.ImmutablePair;
2523
import org.apache.commons.lang3.tuple.Pair;
2624
import org.apache.sysds.runtime.compress.CompressionSettings;
2725
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
26+
import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
2827

2928
/**
3029
* Empty encoding for cases where the entire group of columns is zero
@@ -41,7 +40,7 @@ public IEncode combine(IEncode e) {
4140
}
4241

4342
@Override
44-
public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) {
43+
public Pair<IEncode, HashMapLongInt> combineWithMap(IEncode e) {
4544
return new ImmutablePair<>(e, null);
4645
}
4746

0 commit comments

Comments
 (0)