Skip to content

Commit 6cd34c2

Browse files
committed
better parallelization
1 parent 578308a commit 6cd34c2

File tree

1 file changed

+68
-52
lines changed

1 file changed

+68
-52
lines changed

src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ public class CompressedEncode {
8282

8383
private final AtomicLong nnz = new AtomicLong();
8484

85+
private static final IColIndex SINGLE_COL_TMP_INDEX = ColIndexFactory.create(1);
86+
8587
private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) {
8688
this.enc = enc;
8789
this.in = in;
@@ -107,8 +109,6 @@ private MatrixBlock apply() throws Exception {
107109
final List<AColGroup> groups = isParallel() ? multiThread(encoders) : singleThread(encoders);
108110
final int cols = shiftGroups(groups);
109111
final CompressedMatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups);
110-
111-
combineUncompressed(mb);
112112
mb.setNonZeros(nnz.get());
113113
logging(mb);
114114
return mb;
@@ -125,16 +125,36 @@ private boolean isParallel() {
125125

126126
private List<AColGroup> singleThread(List<ColumnEncoderComposite> encoders) throws Exception {
127127
List<AColGroup> groups = new ArrayList<>(encoders.size());
128-
for(ColumnEncoderComposite c : encoders)
129-
groups.add(encode(c));
128+
List<ColGroupUncompressedArray> ucg = new ArrayList<>();
129+
for(ColumnEncoderComposite c : encoders) {
130+
AColGroup g = encode(c);
131+
if(g instanceof ColGroupUncompressedArray)
132+
ucg.add((ColGroupUncompressedArray) g);
133+
else
134+
groups.add(g);
135+
}
136+
if(ucg.size() > 0) {
137+
groups.add(combine(ucg));
138+
}
130139
return groups;
131140
}
132141

133142
private List<AColGroup> multiThread(List<ColumnEncoderComposite> encoders) throws Exception {
134143
final List<Future<AColGroup>> tasks = new ArrayList<>(encoders.size());
135-
for(ColumnEncoderComposite c : encoders)
136-
tasks.add(pool.submit(() -> encode(c)));
144+
final List<Future<AColGroup>> ucgTasks = new ArrayList<>();
145+
for(ColumnEncoderComposite c : encoders) {
146+
147+
Array<?> a = in.getColumn(c._colID - 1);
148+
if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a))
149+
ucgTasks.add(pool.submit(() -> encode(c)));
150+
else
151+
tasks.add(pool.submit(() -> encode(c)));
152+
}
137153
final List<AColGroup> groups = new ArrayList<>(encoders.size());
154+
if(!ucgTasks.isEmpty()) {
155+
groups.add(combineFutures(ucgTasks));
156+
}
157+
138158
for(Future<AColGroup> t : tasks)
139159
groups.add(t.get());
140160
return groups;
@@ -383,44 +403,48 @@ private ADictionary createRecodeDictionary(boolean containsNull, int domain) {
383403

384404
@SuppressWarnings("unchecked")
385405
private <T> AColGroup passThrough(ColumnEncoderComposite c) throws Exception {
386-
387-
final IColIndex colIndexes = ColIndexFactory.create(1);
388-
final int colId = c._colID;
389-
final Array<T> a = (Array<T>) in.getColumn(colId - 1);
406+
final int colId = c._colID - 1;
407+
final Array<T> a = (Array<T>) in.getColumn(colId);
390408
if(a instanceof ACompressedArray)
391-
return passThroughCompressed(colIndexes, a);
409+
return passThroughCompressed(a);
410+
else if(uncompressedPassThrough(a))
411+
return new ColGroupUncompressedArray(a, colId, SINGLE_COL_TMP_INDEX);
392412
else
393-
return passThroughNormal(c, colIndexes, a);
413+
return compressingPassThrough(c, a);
394414
}
395415

396-
private <T> AColGroup passThroughNormal(ColumnEncoderComposite c, final IColIndex colIndexes, final Array<T> a)
416+
private <T> AColGroup compressingPassThrough(ColumnEncoderComposite c, final Array<T> a)
397417
throws InterruptedException, ExecutionException, Exception {
398-
// Take a small sample
399-
ArrayCompressionStatistics stats = !inputContainsCompressed ? //
400-
a.statistics(Math.min(1000, a.size())) : null;
418+
boolean containsNull = a.containsNull();
419+
estimateRCDMapSize(c);
420+
HashMapToInt<T> map = (HashMapToInt<T>) a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns());
421+
double[] vals = new double[map.size() + (containsNull ? 1 : 0)];
422+
if(containsNull)
423+
vals[map.size()] = Double.NaN;
424+
ValueType t = a.getValueType();
425+
map.forEach((k, v) -> vals[v.intValue() - 1] = UtilFunctions.objectToDouble(t, k));
426+
ADictionary d = Dictionary.create(vals);
427+
AMapToData m = createMappingAMapToData(a, map, containsNull);
428+
AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, m, null);
429+
nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows()));
430+
return ret;
431+
}
432+
433+
private <T> boolean uncompressedPassThrough(final Array<T> a) {
434+
435+
if(a.getValueType() != ValueType.BOOLEAN) {
401436

402-
if(a.getValueType() != ValueType.BOOLEAN // if not booleans
403-
&& (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) {
404-
return new ColGroupUncompressedArray(a, c._colID - 1, colIndexes);
437+
ArrayCompressionStatistics stats = !inputContainsCompressed ? //
438+
a.statistics(Math.min(1000, a.size())) : null;
439+
return stats == null // if some columns already are compressed then most likely we do not need to
440+
|| !stats.shouldCompress // if we should compress ... lets
441+
|| stats.valueType != a.getValueType(); // if the compression says change value type, then do not do it.
405442
}
406-
else {
407-
boolean containsNull = a.containsNull();
408-
estimateRCDMapSize(c);
409-
HashMapToInt<T> map = (HashMapToInt<T>) a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns());
410-
double[] vals = new double[map.size() + (containsNull ? 1 : 0)];
411-
if(containsNull)
412-
vals[map.size()] = Double.NaN;
413-
ValueType t = a.getValueType();
414-
map.forEach((k, v) -> vals[v.intValue() - 1] = UtilFunctions.objectToDouble(t, k));
415-
ADictionary d = Dictionary.create(vals);
416-
AMapToData m = createMappingAMapToData(a, map, containsNull);
417-
AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
418-
nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows()));
419-
return ret;
420-
}
421-
}
422-
423-
private <T> AColGroup passThroughCompressed(final IColIndex colIndexes, final Array<T> a) {
443+
444+
return false;// if not booleans
445+
}
446+
447+
private <T> AColGroup passThroughCompressed(final Array<T> a) {
424448
// only DDC possible currently.
425449
DDCArray<?> aDDC = (DDCArray<?>) a;
426450
Array<?> dict = aDDC.getDict();
@@ -433,7 +457,7 @@ private <T> AColGroup passThroughCompressed(final IColIndex colIndexes, final Ar
433457
vals[i] = dict.getAsDouble(i);
434458

435459
ADictionary d = Dictionary.create(vals);
436-
AColGroup ret = ColGroupDDC.create(colIndexes, d, aDDC.getMap(), null);
460+
AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, aDDC.getMap(), null);
437461

438462
nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows()));
439463
return ret;
@@ -606,21 +630,12 @@ private <T> void estimateRCDMapSize(ColumnEncoderComposite c) {
606630
c._estNumDistincts = estDistCount;
607631
}
608632

609-
private void combineUncompressed(CompressedMatrixBlock mb) throws InterruptedException, ExecutionException {
610-
611-
List<ColGroupUncompressedArray> ucg = new ArrayList<>();
612-
List<AColGroup> ret = new ArrayList<>();
613-
for(AColGroup g : mb.getColGroups()) {
614-
if(g instanceof ColGroupUncompressedArray)
615-
ucg.add((ColGroupUncompressedArray) g);
616-
else
617-
ret.add(g);
633+
private AColGroup combineFutures(List<Future<AColGroup>> ucgTasks) throws InterruptedException, ExecutionException {
634+
List<ColGroupUncompressedArray> ucg = new ArrayList<>(ucgTasks.size());
635+
for(Future<AColGroup> g : ucgTasks) {
636+
ucg.add((ColGroupUncompressedArray) g.get());
618637
}
619-
if(ucg.size() > 0) {
620-
ret.add(combine(ucg));
621-
nnz.addAndGet(ret.get(ret.size() - 1).getNumberNonZeros(in.getNumRows()));
622-
}
623-
mb.allocateColGroupList(ret);
638+
return combine(ucg);
624639
}
625640

626641
private AColGroup combine(List<ColGroupUncompressedArray> ucg) throws InterruptedException, ExecutionException {
@@ -640,6 +655,7 @@ private AColGroup combine(List<ColGroupUncompressedArray> ucg) throws Interrupte
640655

641656
nnz.addAndGet(combinedNNZ);
642657
ret.setNonZeros(combinedNNZ);
658+
643659
return ColGroupUncompressed.create(ret, combinedCols);
644660
}
645661

0 commit comments

Comments
 (0)