@@ -632,19 +632,19 @@ private AColGroup combine(List<ColGroupUncompressedArray> ucg) throws Interrupte
632632 final DenseBlock db = ret .getDenseBlock ();
633633 final int nrow = in .getNumRows ();
634634 final int ncol = combinedCols .size ();
635+ final long combinedNNZ ;
635636 if (isParallel () && (long ) nrow * ncol > 10000 && nrow > 512 )
636- parallelPutInto (ucg , db , nrow , ncol );
637+ combinedNNZ = parallelPutInto (ucg , db , nrow , ncol );
637638 else
638- putInto (ucg , db , 0 , nrow , 0 , ncol );
639-
640- ret .recomputeNonZeros (k );
639+ combinedNNZ = putInto (ucg , db , 0 , nrow , 0 , ncol );
641640
641+ nnz .addAndGet (combinedNNZ );
642642 return ColGroupUncompressed .create (ret , combinedCols );
643643 }
644644
645- private void parallelPutInto (List <ColGroupUncompressedArray > ucg , DenseBlock db , int nrow , int ncol )
645+ private long parallelPutInto (List <ColGroupUncompressedArray > ucg , DenseBlock db , int nrow , int ncol )
646646 throws InterruptedException , ExecutionException {
647- List <Future <? >> tasks = new ArrayList <>();
647+ List <Future <Long >> tasks = new ArrayList <>();
648648
649649 final int iblk = Math .max (512 , nrow / k );
650650 final int jblk = Math .min (128 , ncol );
@@ -655,23 +655,26 @@ private void parallelPutInto(List<ColGroupUncompressedArray> ucg, DenseBlock db,
655655 int sj = j ;
656656 int ej = Math .min (ncol , jblk + j );
657657 tasks .add (pool .submit (() -> {
658- putInto (ucg , db , si , ei , sj , ej );
658+ return putInto (ucg , db , si , ei , sj , ej );
659659 }));
660660 }
661661 }
662-
663- for (Future <?> t : tasks )
664- t .get ();
662+ long nnz = 0 ;
663+ for (Future <Long > t : tasks )
664+ nnz += t .get ();
665+ return nnz ;
665666 }
666667
667- private void putInto (List <ColGroupUncompressedArray > ucg , DenseBlock db , int il , int iu , int jl , int ju ) {
668+ private long putInto (List <ColGroupUncompressedArray > ucg , DenseBlock db , int il , int iu , int jl , int ju ) {
669+ long nnz = 0 ;
668670 for (int i = il ; i < iu ; i ++) {
669671 final double [] rval = db .values (i );
670672 final int off = db .pos (i );
671673 for (int j = jl ; j < ju ; j ++) {
672- rval [off + j ] = ucg .get (j ).array .getAsDouble (i );
674+ nnz += ( rval [off + j ] = ucg .get (j ).array .getAsDouble (i )) == 0.0 ? 1 : 0 ;
673675 }
674676 }
677+ return nnz ;
675678 }
676679
677680 private void logging (MatrixBlock mb ) {
0 commit comments