Skip to content

Commit b6adff8

Browse files
committed
[MINOR] CompressedMatrixBlock parallel nonzero count
1 parent 85331dc commit b6adff8

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.ArrayList;
2929
import java.util.Iterator;
3030
import java.util.List;
31+
import java.util.concurrent.ExecutorService;
3132
import java.util.concurrent.Future;
3233

3334
import org.apache.commons.lang3.NotImplementedException;
@@ -88,6 +89,7 @@
8889
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
8990
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
9091
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
92+
import org.apache.sysds.runtime.util.CommonThreadPool;
9193
import org.apache.sysds.runtime.util.IndexRange;
9294
import org.apache.sysds.utils.DMLCompressionStatistics;
9395
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
@@ -319,6 +321,35 @@ public long recomputeNonZeros() {
319321
return nonZeros;
320322
}
321323

324+
@Override
325+
public long recomputeNonZeros(int k) {
326+
if(k <= 1 || isOverlapping() || _colGroups.size() <= 1)
327+
return recomputeNonZeros();
328+
329+
final ExecutorService pool = CommonThreadPool.get(k);
330+
try {
331+
List<Future<Long>> tasks = new ArrayList<>();
332+
for(AColGroup g : _colGroups)
333+
tasks.add(pool.submit(() -> g.getNumberNonZeros(rlen)));
334+
335+
long nnz = 0;
336+
for(Future<Long> t : tasks)
337+
nnz += t.get();
338+
nonZeros = nnz;
339+
}
340+
catch(Exception e) {
341+
throw new DMLRuntimeException("Failed to count non zeros", e);
342+
}
343+
finally {
344+
pool.shutdown();
345+
}
346+
347+
if(nonZeros == 0) // If there is no nonzeros then reallocate into single empty column group.
348+
allocateColGroup(ColGroupEmpty.create(getNumColumns()));
349+
350+
return nonZeros;
351+
}
352+
322353
@Override
323354
public long recomputeNonZeros(int rl, int ru) {
324355
throw new NotImplementedException();

0 commit comments

Comments
 (0)