Skip to content

Commit d39895a

Browse files
committed
count nnz
1 parent fbee615 commit d39895a

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -632,19 +632,19 @@ private AColGroup combine(List<ColGroupUncompressedArray> ucg) throws Interrupte
632632
final DenseBlock db = ret.getDenseBlock();
633633
final int nrow = in.getNumRows();
634634
final int ncol = combinedCols.size();
635+
final long combinedNNZ;
635636
if(isParallel() && (long) nrow * ncol > 10000 && nrow > 512)
636-
parallelPutInto(ucg, db, nrow, ncol);
637+
combinedNNZ = parallelPutInto(ucg, db, nrow, ncol);
637638
else
638-
putInto(ucg, db, 0, nrow, 0, ncol);
639-
640-
ret.recomputeNonZeros(k);
639+
combinedNNZ = putInto(ucg, db, 0, nrow, 0, ncol);
641640

641+
nnz.addAndGet(combinedNNZ);
642642
return ColGroupUncompressed.create(ret, combinedCols);
643643
}
644644

645-
private void parallelPutInto(List<ColGroupUncompressedArray> ucg, DenseBlock db, int nrow, int ncol)
645+
private long parallelPutInto(List<ColGroupUncompressedArray> ucg, DenseBlock db, int nrow, int ncol)
646646
throws InterruptedException, ExecutionException {
647-
List<Future<?>> tasks = new ArrayList<>();
647+
List<Future<Long>> tasks = new ArrayList<>();
648648

649649
final int iblk = Math.max(512, nrow / k);
650650
final int jblk = Math.min(128, ncol);
@@ -655,23 +655,26 @@ private void parallelPutInto(List<ColGroupUncompressedArray> ucg, DenseBlock db,
655655
int sj = j;
656656
int ej = Math.min(ncol, jblk + j);
657657
tasks.add(pool.submit(() -> {
658-
putInto(ucg, db, si, ei, sj, ej);
658+
return putInto(ucg, db, si, ei, sj, ej);
659659
}));
660660
}
661661
}
662-
663-
for(Future<?> t : tasks)
664-
t.get();
662+
long nnz = 0;
663+
for(Future<Long> t : tasks)
664+
nnz += t.get();
665+
return nnz;
665666
}
666667

667-
private void putInto(List<ColGroupUncompressedArray> ucg, DenseBlock db, int il, int iu, int jl, int ju) {
668+
private long putInto(List<ColGroupUncompressedArray> ucg, DenseBlock db, int il, int iu, int jl, int ju) {
669+
long nnz = 0;
668670
for(int i = il; i < iu; i++) {
669671
final double[] rval = db.values(i);
670672
final int off = db.pos(i);
671673
for(int j = jl; j < ju; j++) {
672-
rval[off + j] = ucg.get(j).array.getAsDouble(i);
674+
nnz += (rval[off + j] = ucg.get(j).array.getAsDouble(i)) == 0.0 ? 1 : 0;
673675
}
674676
}
677+
return nnz;
675678
}
676679

677680
private void logging(MatrixBlock mb) {

0 commit comments

Comments
 (0)