5454import org .apache .sysds .runtime .frame .data .FrameBlock ;
5555import org .apache .sysds .runtime .frame .data .columns .ACompressedArray ;
5656import org .apache .sysds .runtime .frame .data .columns .Array ;
57+ import org .apache .sysds .runtime .frame .data .columns .ArrayFactory ;
5758import org .apache .sysds .runtime .frame .data .columns .DDCArray ;
59+ import org .apache .sysds .runtime .frame .data .columns .DoubleArray ;
5860import org .apache .sysds .runtime .frame .data .columns .HashMapToInt ;
5961import org .apache .sysds .runtime .frame .data .compress .ArrayCompressionStatistics ;
6062import org .apache .sysds .runtime .matrix .data .MatrixBlock ;
@@ -136,7 +138,7 @@ private List<AColGroup> singleThread(List<ColumnEncoderComposite> encoders) thro
136138 }
137139 if (ucg .size () > 0 )
138140 groups .add (combine (ucg ));
139-
141+
140142 return groups ;
141143 }
142144
@@ -175,24 +177,24 @@ private int shiftGroups(List<AColGroup> groups) {
175177 final IntArrayList ucCols = new IntArrayList ();
176178
177179 // ColIndexFactory.create
178- for (int i = 0 ; i < encoders .size (); i ++ ) {
180+ for (int i = 0 ; i < encoders .size (); i ++) {
179181 // for each encoder ...
180182 ColumnEncoderComposite c = encoders .get (i );
181183 Array <?> a = in .getColumn (c ._colID - 1 );
182- if (c .isPassThrough () && !(a instanceof ACompressedArray ) && uncompressedPassThrough (a )){
184+ if (c .isPassThrough () && !(a instanceof ACompressedArray ) && uncompressedPassThrough (a )) {
183185 // if this encoder was part of the uncompressed encoders.
184186 // do not shift the column indexes because we combined all uncompressed columnGroups.
185187 ucCols .appendValue (curCols ++);
186188 }
187189 else {
188190 AColGroup g = groups .get (curGroup );
189- groups .set ( curGroup , g .shiftColIndices (curCols ));
191+ groups .set (curGroup , g .shiftColIndices (curCols ));
190192 curCols += g .getColIndices ().size ();
191193 }
192194 }
193- if ( ucCols .size () > 0 ){
194- int i = groups .size ()- 1 ;
195- AColGroup g =groups .get (i );
195+ if (ucCols .size () > 0 ) {
196+ int i = groups .size () - 1 ;
197+ AColGroup g = groups .get (i );
196198 groups .set (i , g .copyAndSet (ColIndexFactory .create (ucCols )));
197199 }
198200 return curCols ;
@@ -463,17 +465,13 @@ private <T> boolean uncompressedPassThrough(final Array<T> a) {
463465 return false ;// if not booleans
464466 }
465467
466- private <T > AColGroup passThroughCompressed (final Array <T > a ) {
468+ private <T > AColGroup passThroughCompressed (final Array <T > a ) throws InterruptedException , ExecutionException {
467469 // only DDC possible currently.
468470 DDCArray <?> aDDC = (DDCArray <?>) a ;
469471 Array <?> dict = aDDC .getDict ();
470- double [] vals = new double [dict .size ()];
471- if (a .containsNull ())
472- for (int i = 0 ; i < dict .size (); i ++)
473- vals [i ] = dict .getAsNaNDouble (i );
474- else
475- for (int i = 0 ; i < dict .size (); i ++)
476- vals [i ] = dict .getAsDouble (i );
472+ final int dSize = dict .size ();
473+
474+ final double [] vals = passThroughCompressedCreateDict (a , dict , dSize );
477475
478476 ADictionary d = Dictionary .create (vals );
479477 AColGroup ret = ColGroupDDC .create (SINGLE_COL_TMP_INDEX , d , aDDC .getMap (), null );
@@ -482,6 +480,52 @@ private <T> AColGroup passThroughCompressed(final Array<T> a) {
482480 return ret ;
483481 }
484482
483+ private <T > double [] passThroughCompressedCreateDict (final Array <T > a , Array <?> dict , final int dSize ) throws InterruptedException , ExecutionException {
484+ final double [] vals ;
485+ final boolean nulls = a .containsNull ();
486+ if (dict .getValueType () == ValueType .FP64 && !nulls ) {
487+ DoubleArray converted = ((DoubleArray ) dict );
488+ vals = converted .get ();
489+ }
490+ else if (!nulls ) {
491+ DoubleArray converted = ArrayFactory .create (new double [dSize ]);
492+ passThroughTransferNoNulls (dict , dSize , converted );
493+ vals = converted .get ();
494+ }
495+ else {
496+ vals = passThroughTransferNulls (dict , dSize );
497+ }
498+ return vals ;
499+ }
500+
501+ private double [] passThroughTransferNulls (Array <?> dict , final int dSize ) {
502+ final double [] vals ;
503+ vals = new double [dSize ];
504+ for (int i = 0 ; i < dSize ; i ++) {
505+ vals [i ] = dict .getAsNaNDouble (i );
506+ }
507+ return vals ;
508+ }
509+
510+ private void passThroughTransferNoNulls (Array <?> dict , final int dSize , DoubleArray converted ) throws InterruptedException , ExecutionException {
511+ if (isParallel () && dSize > 10000 ){
512+ final int blkz = Math .min (10000 , (dSize + k ) / k );
513+ final List <Future <?>> tasks = new ArrayList <>();
514+ for (int i = 0 ; i < dSize ; i += blkz ){
515+ int si = i ;
516+ int ei = Math .min (dSize , i + blkz );
517+ tasks .add (pool .submit (() -> {
518+ dict .changeType (converted , si , ei );
519+ }));
520+ }
521+ for (Future <?> t : tasks )
522+ t .get ();
523+ }
524+ else {
525+ dict .changeType (converted , 0 , dSize );
526+ }
527+ }
528+
485529 private <T > AMapToData createMappingAMapToData (Array <T > a , HashMapToInt <T > map , boolean containsNull )
486530 throws Exception {
487531 final int si = map .size ();
@@ -674,8 +718,8 @@ private AColGroup combine(List<ColGroupUncompressedArray> ucg) throws Interrupte
674718 nnz .addAndGet (combinedNNZ );
675719 ret .setNonZeros (combinedNNZ );
676720 if (LOG .isDebugEnabled ())
677- LOG .debug ("Combining of : " + ucg .size () + " uncompressed columns Time:" + t );
678-
721+ LOG .debug ("Combining of : " + ucg .size () + " uncompressed columns Time: " + t . stop () );
722+
679723 return ColGroupUncompressed .create (ret , combinedCols );
680724 }
681725
@@ -708,7 +752,7 @@ private long putInto(List<ColGroupUncompressedArray> ucg, DenseBlock db, int il,
708752 final double [] rval = db .values (i );
709753 final int off = db .pos (i );
710754 for (int j = jl ; j < ju ; j ++) {
711- nnz += (rval [off + j ] = ucg .get (j ).array .getAsDouble (i )) == 0.0 ? 1 : 0 ;
755+ nnz += (rval [off + j ] = ucg .get (j ).array .getAsNaNDouble (i )) == 0.0 ? 1 : 0 ;
712756 }
713757 }
714758 return nnz ;
0 commit comments