Skip to content

Commit fdb32c3

Browse files
committed
Changes
- implemented the new method matrixMultDenseDenseMM, which uses different iteration orders when dealing with transposes in order to achieve a speedup - added one test class for the correctness of the implementation - added one test class to check whether the running time has improved
1 parent c7300f3 commit fdb32c3

File tree

3 files changed

+303
-0
lines changed

3 files changed

+303
-0
lines changed

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,146 @@ public static void matrixMultWuMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV
10261026
// optimized matrix mult implementation //
10271027
//////////////////////////////////////////
10281028

1029+
public static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, DenseBlock c, boolean transA, boolean transB, int n, int cd, int rl, int ru, int cl, int cu) {
1030+
// C = A %*% B
1031+
if (!transA && !transB) {
1032+
matrixMultDenseDenseMM(a, b, c, n, cd, rl, ru, cl, cu);
1033+
return;
1034+
}
1035+
// C = t(A) %*% B
1036+
if (transA && !transB) {
1037+
multDenseDenseTransA(a, b, c, n, cd, rl, ru, cl, cu);
1038+
return;
1039+
}
1040+
// C = A %*% t(B)
1041+
if (!transA && transB) {
1042+
multDenseDenseTransB(a, b, c, n, cd, rl, ru, cl, cu);
1043+
return;
1044+
}
1045+
// C = t(A) %*% t(B)
1046+
if (transA && transB) {
1047+
multDenseDenseTransATransB(a, b, c, n, cd, rl, ru, cl, cu);
1048+
return;
1049+
}
1050+
}
1051+
1052+
private static void multDenseDenseTransA(DenseBlock a, DenseBlock b, DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
1053+
final int blocksizeI = 32;
1054+
final int blocksizeK = 24;
1055+
1056+
for (int bi = rl; bi < ru; bi += blocksizeI) {
1057+
int bimin = Math.min(ru, bi + blocksizeI);
1058+
for (int bk = 0; bk < cd; bk += blocksizeK) {
1059+
int bkmin = Math.min(cd, bk + blocksizeK);
1060+
1061+
int k = bk;
1062+
for (; k < bkmin - 3; k += 4) {
1063+
if (b.isContiguous()) {
1064+
double[] bvals = b.values(0);
1065+
int bix0 = b.pos(k, cl), bix1 = b.pos(k+1, cl);
1066+
int bix2 = b.pos(k+2, cl), bix3 = b.pos(k+3, cl);
1067+
1068+
for (int i = bi; i < bimin; i++) {
1069+
double[] cvals = c.values(i);
1070+
int cix = c.pos(i, cl);
1071+
double val0 = a.values(k)[a.pos(k) + i];
1072+
double val1 = a.values(k+1)[a.pos(k+1) + i];
1073+
double val2 = a.values(k+2)[a.pos(k+2) + i];
1074+
double val3 = a.values(k+3)[a.pos(k+3) + i];
1075+
1076+
vectMultiplyAdd4(val0, val1, val2, val3,
1077+
bvals, cvals,
1078+
bix0, bix1, bix2, bix3, cix, cu - cl);
1079+
}
1080+
} else {
1081+
for (int i = bi; i < bimin; i++) {
1082+
double[] cvals = c.values(i);
1083+
int cix = c.pos(i, cl);
1084+
double val0 = a.values(k)[a.pos(k) + i];
1085+
if(val0!=0) vectMultiplyAdd(val0, b.values(k), cvals, b.pos(k, cl), cix, cu - cl);
1086+
double val1 = a.values(k+1)[a.pos(k+1) + i];
1087+
if(val1!=0) vectMultiplyAdd(val1, b.values(k+1), cvals, b.pos(k+1, cl), cix, cu - cl);
1088+
double val2 = a.values(k+2)[a.pos(k+2) + i];
1089+
if(val2!=0) vectMultiplyAdd(val2, b.values(k+2), cvals, b.pos(k+2, cl), cix, cu - cl);
1090+
double val3 = a.values(k+3)[a.pos(k+3) + i];
1091+
if(val3!=0) vectMultiplyAdd(val3, b.values(k+3), cvals, b.pos(k+3, cl), cix, cu - cl);
1092+
}
1093+
}
1094+
}
1095+
for (; k < bkmin; k++) {
1096+
double[] bvals = b.values(k);
1097+
int bix = b.pos(k, cl);
1098+
double[] avals = a.values(k);
1099+
int apos = a.pos(k);
1100+
for (int i = bi; i < bimin; i++) {
1101+
double val = avals[apos + i];
1102+
if (val != 0) {
1103+
vectMultiplyAdd(val, bvals, c.values(i), bix, c.pos(i, cl), cu - cl);
1104+
}
1105+
}
1106+
}
1107+
}
1108+
}
1109+
}
1110+
1111+
private static void multDenseDenseTransB(DenseBlock a, DenseBlock b, DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
1112+
final int blocksizeK = 24;
1113+
double[] bufB = new double[blocksizeK * (cu - cl)];
1114+
1115+
for (int bk = 0; bk < cd; bk += blocksizeK) {
1116+
int bkmin = Math.min(cd, bk + blocksizeK);
1117+
int bklen = bkmin - bk;
1118+
1119+
1120+
for (int j = cl; j < cu; j++) {
1121+
double[] bvals = b.values(j);
1122+
int bpos = b.pos(j);
1123+
1124+
for (int k = 0; k < bklen; k++) {
1125+
bufB[k * (cu-cl) + (j-cl)] = bvals[bpos + bk + k];
1126+
}
1127+
}
1128+
1129+
for (int i = rl; i < ru; i++) {
1130+
double[] avals = a.values(i);
1131+
int apos = a.pos(i);
1132+
double[] cvals = c.values(i);
1133+
int cix = c.pos(i, cl);
1134+
1135+
for (int k = 0; k < bklen; k++) {
1136+
double val = avals[apos + bk + k];
1137+
if (val != 0) {
1138+
int bufIx = k * (cu-cl);
1139+
vectMultiplyAdd(val, bufB, cvals, bufIx, cix, cu - cl);
1140+
}
1141+
}
1142+
}
1143+
}
1144+
}
1145+
1146+
private static void multDenseDenseTransATransB(DenseBlock a, DenseBlock b, DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
1147+
double[] d_row = new double[ru];
1148+
1149+
for (int j = cl; j < cu; j++) {
1150+
java.util.Arrays.fill(d_row, rl, ru, 0);
1151+
1152+
for (int k = 0; k < cd; k++) {
1153+
double valB = b.get(j, k);
1154+
if (valB != 0) {
1155+
double[] avals = a.values(k);
1156+
int apos = a.pos(k);
1157+
vectMultiplyAdd(valB, avals, d_row, apos + rl, rl, ru - rl);
1158+
}
1159+
}
1160+
1161+
for (int i = rl; i < ru; i++) {
1162+
double[] cvals = c.values(i);
1163+
int cix = c.pos(i);
1164+
cvals[cix + j] += d_row[i];
1165+
}
1166+
}
1167+
}
1168+
10291169
private static void matrixMultDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean tm2, boolean pm2, int rl, int ru, int cl, int cu) {
10301170
DenseBlock a = m1.getDenseBlock();
10311171
DenseBlock b = m2.getDenseBlock();
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package org.apache.sysds.test.component.matrixmult;
2+
3+
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
4+
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
5+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
6+
import org.junit.Assert;
7+
import org.junit.Test;
8+
9+
public class MatrixMultTransposedPerformanceTest {
10+
private final int m = 500;
11+
private final int n = 500;
12+
private final int k = 500;
13+
14+
15+
@Test
16+
public void testPerf_1_NoTransA_TransB() {
17+
System.out.println("Case: C = A %*% t(B)");
18+
runTest(false, true);
19+
System.out.println();
20+
}
21+
22+
@Test
23+
public void testPerf_2_TransA_NoTransB() {
24+
System.out.println("Case: C = t(A) %*% B");
25+
runTest(true, false);
26+
System.out.println();
27+
}
28+
29+
@Test
30+
public void testPerf_3_TransA_TransB() {
31+
System.out.println("Case: C = t(A) %*% t(B)");
32+
runTest(true, true);
33+
}
34+
35+
private void runTest(boolean tA, boolean tB) {
36+
int REP = 50;
37+
38+
// setup Dimensions
39+
int rowsA = tA ? k : m;
40+
int colsA = tA ? m : k;
41+
int rowsB = tB ? n : k;
42+
int colsB = tB ? k : n;
43+
44+
MatrixBlock A = MatrixBlock.randOperations(rowsA, colsA, 1.0, -1, 1, "uniform", 7);
45+
MatrixBlock B = MatrixBlock.randOperations(rowsB, colsB, 1.0, -1, 1, "uniform", 3);
46+
MatrixBlock C = new MatrixBlock(m, n, false);
47+
C.allocateDenseBlock();
48+
49+
50+
for(int i=0; i<50; i++) {
51+
runOldMethod(A, B, tA, tB);
52+
runNewKernel(A, B, C, tA, tB);
53+
}
54+
55+
56+
long startTimeOld = System.nanoTime();
57+
for(int i = 0; i < REP; i++) {
58+
runOldMethod(A, B, tA, tB);
59+
}
60+
double avgTimeOld = (System.nanoTime() - startTimeOld) / 1e6 / REP;
61+
62+
63+
double startTimeNew = System.nanoTime();
64+
for(int i = 0; i < REP; i++) {
65+
runNewKernel(A, B, C, tA, tB);
66+
}
67+
double avgTimeNew = (System.nanoTime() - startTimeNew) / 1e6 / REP;
68+
69+
System.out.printf("Old Method: %.3f ms | New Kernel: %.3f ms", avgTimeOld, avgTimeNew);
70+
71+
Assert.assertTrue(avgTimeNew < avgTimeOld);
72+
}
73+
74+
private void runNewKernel(MatrixBlock A, MatrixBlock B, MatrixBlock C, boolean tA, boolean tB) {
75+
C.reset();
76+
77+
LibMatrixMult.matrixMultDenseDenseMM(A.getDenseBlock(), B.getDenseBlock(), C.getDenseBlock(), tA, tB, m, k, 0, m, 0, n);
78+
}
79+
80+
private void runOldMethod(MatrixBlock A, MatrixBlock B, boolean tA, boolean tB) {
81+
// do transpose if needed
82+
MatrixBlock A_in = tA ? LibMatrixReorg.transpose(A) : A;
83+
MatrixBlock B_in = tB ? LibMatrixReorg.transpose(B) : B;
84+
85+
MatrixBlock C = new MatrixBlock(m, n, false);
86+
C.allocateDenseBlock();
87+
88+
LibMatrixMult.matrixMultDenseDenseMM(A_in.getDenseBlock(), B_in.getDenseBlock(), C.getDenseBlock(), false,
89+
false, m, k, 0, m, 0, n);
90+
}
91+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package org.apache.sysds.test.component.matrixmult;
2+
3+
import org.apache.sysds.runtime.data.DenseBlock;
4+
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
5+
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
6+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
7+
import org.apache.sysds.test.TestUtils;
8+
import org.junit.Test;
9+
10+
import java.util.Random;
11+
12+
public class MatrixMultTransposedTest {
13+
14+
// run multiple random scenarios
15+
@Test
16+
public void testCase_noTransA_TransB() {
17+
for(int i=0; i<10; i++) {
18+
runTest(false, true);
19+
}
20+
}
21+
22+
@Test
23+
public void testCase_TransA_NoTransB() {
24+
for(int i=0; i<10; i++) {
25+
runTest(true, false);
26+
}
27+
}
28+
29+
@Test
30+
public void testCase_TransA_TransB() {
31+
for(int i=0; i<10; i++) {
32+
runTest(true, true);
33+
}
34+
}
35+
36+
private void runTest(boolean tA, boolean tB) {
37+
Random rand = new Random();
38+
39+
// generate random dimensions between 1 and 300
40+
int m = rand.nextInt(300) + 1;
41+
int n = rand.nextInt(300) + 1;
42+
int k = rand.nextInt(300) + 1;
43+
44+
45+
int rowsA = tA ? k : m;
46+
int colsA = tA ? m : k;
47+
int rowsB = tB ? n : k;
48+
int colsB = tB ? k : n;
49+
50+
MatrixBlock ma = MatrixBlock.randOperations(rowsA, colsA, 1.0, -1, 1, "uniform", 7);
51+
MatrixBlock mb = MatrixBlock.randOperations(rowsB, colsB, 1.0, -1, 1, "uniform", 3);
52+
53+
MatrixBlock mc = new MatrixBlock(m, n, false);
54+
mc.allocateDenseBlock();
55+
56+
DenseBlock a = ma.getDenseBlock();
57+
DenseBlock b = mb.getDenseBlock();
58+
DenseBlock c = mc.getDenseBlock();
59+
60+
LibMatrixMult.matrixMultDenseDenseMM(a, b, c, tA, tB, n, k, 0, m, 0, n);
61+
62+
mc.recomputeNonZeros();
63+
64+
// calc true result with existing methods
65+
MatrixBlock ma_in = tA ? LibMatrixReorg.transpose(ma) : ma;
66+
MatrixBlock mb_in = tB ? LibMatrixReorg.transpose(mb) : mb;
67+
MatrixBlock expected = LibMatrixMult.matrixMult(ma_in, mb_in);
68+
69+
// compare results
70+
TestUtils.compareMatrices(expected, mc, 1e-8);
71+
}
72+
}

0 commit comments

Comments
 (0)