Skip to content

Commit 51dc823

Browse files
committed
parallel
1 parent 4c0f7eb commit 51dc823

File tree

1 file changed

+62
-18
lines changed

1 file changed

+62
-18
lines changed

src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@
5454
import org.apache.sysds.runtime.frame.data.FrameBlock;
5555
import org.apache.sysds.runtime.frame.data.columns.ACompressedArray;
5656
import org.apache.sysds.runtime.frame.data.columns.Array;
57+
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
5758
import org.apache.sysds.runtime.frame.data.columns.DDCArray;
59+
import org.apache.sysds.runtime.frame.data.columns.DoubleArray;
5860
import org.apache.sysds.runtime.frame.data.columns.HashMapToInt;
5961
import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics;
6062
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -136,7 +138,7 @@ private List<AColGroup> singleThread(List<ColumnEncoderComposite> encoders) thro
136138
}
137139
if(ucg.size() > 0)
138140
groups.add(combine(ucg));
139-
141+
140142
return groups;
141143
}
142144

@@ -175,24 +177,24 @@ private int shiftGroups(List<AColGroup> groups) {
175177
final IntArrayList ucCols = new IntArrayList();
176178

177179
// ColIndexFactory.create
178-
for(int i = 0; i < encoders.size(); i++ ){
180+
for(int i = 0; i < encoders.size(); i++) {
179181
// for each encoder ...
180182
ColumnEncoderComposite c = encoders.get(i);
181183
Array<?> a = in.getColumn(c._colID - 1);
182-
if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a)){
184+
if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a)) {
183185
// if this encoder was part of the uncompressed encoders.
184186
// do not shift the column indexes because we combined all uncompressed columnGroups.
185187
ucCols.appendValue(curCols++);
186188
}
187189
else {
188190
AColGroup g = groups.get(curGroup);
189-
groups.set( curGroup, g.shiftColIndices(curCols));
191+
groups.set(curGroup, g.shiftColIndices(curCols));
190192
curCols += g.getColIndices().size();
191193
}
192194
}
193-
if( ucCols.size() > 0){
194-
int i = groups.size()-1;
195-
AColGroup g =groups.get(i);
195+
if(ucCols.size() > 0) {
196+
int i = groups.size() - 1;
197+
AColGroup g = groups.get(i);
196198
groups.set(i, g.copyAndSet(ColIndexFactory.create(ucCols)));
197199
}
198200
return curCols;
@@ -463,17 +465,13 @@ private <T> boolean uncompressedPassThrough(final Array<T> a) {
463465
return false;// if not booleans
464466
}
465467

466-
private <T> AColGroup passThroughCompressed(final Array<T> a) {
468+
private <T> AColGroup passThroughCompressed(final Array<T> a) throws InterruptedException, ExecutionException {
467469
// only DDC possible currently.
468470
DDCArray<?> aDDC = (DDCArray<?>) a;
469471
Array<?> dict = aDDC.getDict();
470-
double[] vals = new double[dict.size()];
471-
if(a.containsNull())
472-
for(int i = 0; i < dict.size(); i++)
473-
vals[i] = dict.getAsNaNDouble(i);
474-
else
475-
for(int i = 0; i < dict.size(); i++)
476-
vals[i] = dict.getAsDouble(i);
472+
final int dSize = dict.size();
473+
474+
final double[] vals = passThroughCompressedCreateDict(a, dict, dSize);
477475

478476
ADictionary d = Dictionary.create(vals);
479477
AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, aDDC.getMap(), null);
@@ -482,6 +480,52 @@ private <T> AColGroup passThroughCompressed(final Array<T> a) {
482480
return ret;
483481
}
484482

483+
private <T> double[] passThroughCompressedCreateDict(final Array<T> a, Array<?> dict, final int dSize) throws InterruptedException, ExecutionException {
484+
final double[] vals;
485+
final boolean nulls = a.containsNull();
486+
if(dict.getValueType() == ValueType.FP64 && !nulls) {
487+
DoubleArray converted = ((DoubleArray) dict);
488+
vals = converted.get();
489+
}
490+
else if(!nulls) {
491+
DoubleArray converted = ArrayFactory.create(new double[dSize]);
492+
passThroughTransferNoNulls(dict, dSize, converted);
493+
vals = converted.get();
494+
}
495+
else {
496+
vals = passThroughTransferNulls(dict, dSize);
497+
}
498+
return vals;
499+
}
500+
501+
private double[] passThroughTransferNulls(Array<?> dict, final int dSize) {
502+
final double[] vals;
503+
vals = new double[dSize];
504+
for(int i = 0; i < dSize; i++) {
505+
vals[i] = dict.getAsNaNDouble(i);
506+
}
507+
return vals;
508+
}
509+
510+
private void passThroughTransferNoNulls(Array<?> dict, final int dSize, DoubleArray converted) throws InterruptedException, ExecutionException {
511+
if(isParallel() && dSize > 10000){
512+
final int blkz = Math.min(10000 , (dSize + k) / k);
513+
final List<Future<?>> tasks = new ArrayList<>();
514+
for(int i = 0; i < dSize ; i += blkz){
515+
int si = i;
516+
int ei = Math.min(dSize, i + blkz);
517+
tasks.add(pool.submit(() -> {
518+
dict.changeType(converted, si, ei);
519+
}));
520+
}
521+
for(Future<?> t : tasks)
522+
t.get();
523+
}
524+
else {
525+
dict.changeType(converted, 0, dSize);
526+
}
527+
}
528+
485529
private <T> AMapToData createMappingAMapToData(Array<T> a, HashMapToInt<T> map, boolean containsNull)
486530
throws Exception {
487531
final int si = map.size();
@@ -674,8 +718,8 @@ private AColGroup combine(List<ColGroupUncompressedArray> ucg) throws Interrupte
674718
nnz.addAndGet(combinedNNZ);
675719
ret.setNonZeros(combinedNNZ);
676720
if(LOG.isDebugEnabled())
677-
LOG.debug("Combining of : " + ucg.size() + " uncompressed columns Time:" + t);
678-
721+
LOG.debug("Combining of : " + ucg.size() + " uncompressed columns Time: " + t.stop());
722+
679723
return ColGroupUncompressed.create(ret, combinedCols);
680724
}
681725

@@ -708,7 +752,7 @@ private long putInto(List<ColGroupUncompressedArray> ucg, DenseBlock db, int il,
708752
final double[] rval = db.values(i);
709753
final int off = db.pos(i);
710754
for(int j = jl; j < ju; j++) {
711-
nnz += (rval[off + j] = ucg.get(j).array.getAsDouble(i)) == 0.0 ? 1 : 0;
755+
nnz += (rval[off + j] = ucg.get(j).array.getAsNaNDouble(i)) == 0.0 ? 1 : 0;
712756
}
713757
}
714758
return nnz;

0 commit comments

Comments
 (0)