Skip to content

Commit 7b34a67

Browse files
committed
[SYSTEMDS-3896] Leverage SIMD Vector API for Counting NNZ
This patch leverages the new Vector API for the core primitive of counting the number of non-zeros (which remains single-threaded because it is usually invoked on chunks as part of multi-threaded tasks). For single-threaded computeNnz on an 8GB dense matrix after JIT compilation, this patch improved performance from 1100ms to 850ms.
1 parent ca8d209 commit 7b34a67

File tree

1 file changed

+12
-17
lines changed

1 file changed

+12
-17
lines changed

src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,15 @@
5858
import org.apache.sysds.runtime.meta.TensorCharacteristics;
5959
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
6060

61+
import jdk.incubator.vector.DoubleVector;
62+
import jdk.incubator.vector.VectorSpecies;
63+
6164
public class UtilFunctions {
6265
protected static final Log LOG = LogFactory.getLog(UtilFunctions.class.getName());
66+
private static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
67+
private static final int vLen = SPECIES.length();
6368

69+
6470
private UtilFunctions(){
6571
// empty private constructor
6672
// making all calls static
@@ -876,25 +882,14 @@ public static boolean isNonZero(Object obj) {
876882
public static int computeNnz(final double[] a, final int ai, final int len) {
877883
int lnnz = 0;
878884
final int end = ai + len;
879-
final int h = (end - ai) % 8;
885+
final int rest = (end - ai) % vLen;
880886

881-
for(int i = ai; i < ai + h; i++)
887+
for(int i = ai; i < ai + rest; i++)
882888
lnnz += (a[i] != 0.0) ? 1 : 0;
883-
for(int i = ai + h; i < end; i += 8)
884-
lnnz += computeNnzBy8(a, i);
885-
return lnnz;
886-
}
887-
888-
private static int computeNnzBy8(final double[] a, final int i) {
889-
int lnnz = 0;
890-
lnnz += (a[i] != 0.0) ? 1 : 0;
891-
lnnz += (a[i+1] != 0.0) ? 1 : 0;
892-
lnnz += (a[i+2] != 0.0) ? 1 : 0;
893-
lnnz += (a[i+3] != 0.0) ? 1 : 0;
894-
lnnz += (a[i+4] != 0.0) ? 1 : 0;
895-
lnnz += (a[i+5] != 0.0) ? 1 : 0;
896-
lnnz += (a[i+6] != 0.0) ? 1 : 0;
897-
lnnz += (a[i+7] != 0.0) ? 1 : 0;
889+
for(int i = ai + rest; i < end; i += 8) {
890+
DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, i);
891+
lnnz += vLen-aVec.eq(0).trueCount();
892+
}
898893
return lnnz;
899894
}
900895

0 commit comments

Comments
 (0)