Skip to content

Commit 6c53230

Browse files
committed
[MINOR] Default Matrix Multiplication specializations
This commit adds specializations for matrix multiplication with: 1. dense-sparse with sparse output 2. ultra sparse out dense dense in. 3. sparse out on sparse vector right side in. Furthermore, i modified the call stack to branch to native mm inside LibMatrixMult, to allow easy native support for CLA, by calling LibMatrixMult, instead of having to go though a MatrixBlock.
1 parent 804518a commit 6c53230

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,13 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock
159159
* @return ret Matrix Block
160160
*/
161161
public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
162+
if(NativeHelper.isNativeLibraryLoaded())
163+
return LibMatrixNative.matrixMult(m1, m2, ret, k);
164+
else
165+
return matrixMult(m1, m2, ret, false, k);
166+
}
167+
168+
public static MatrixBlock matrixMultNonNative(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
162169
return matrixMult(m1, m2, ret, false, k);
163170
}
164171

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock
140140
else
141141
LOG.warn("Was valid for native MM but native lib was not loaded");
142142

143-
return LibMatrixMult.matrixMult(m1, m2, ret, k);
143+
return LibMatrixMult.matrixMultNonNative(m1, m2, ret, k);
144144
}
145145

146146
public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean leftTrans, int k) {

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4994,10 +4994,7 @@ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
49944994
public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
49954995
checkAggregateBinaryOperations(m1, m2, op);
49964996
final int k = op.getNumThreads();
4997-
if(NativeHelper.isNativeLibraryLoaded())
4998-
return LibMatrixNative.matrixMult(m1, m2, ret, k);
4999-
else
5000-
return LibMatrixMult.matrixMult(m1, m2, ret, k);
4997+
return LibMatrixMult.matrixMult(m1, m2, ret, k);
50014998
}
50024999

50035000
protected void checkAggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op) {

0 commit comments

Comments
 (0)