Skip to content

Commit 4418c26

Browse files
committed
GEMV implementation
Signed-off-by: Amrita H S <[email protected]>
1 parent 5845661 commit 4418c26

File tree

1 file changed

+61
-11
lines changed

1 file changed

+61
-11
lines changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 61 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2125,7 +2125,7 @@ class tinyBLAS_PPC {
21252125
switch(m_rem) {
21262126
case 1:
21272127
mc = 1;
2128-
gemm_small(m0, m, n0, n, mc, nc);
2128+
gemv(m0, m, n0, n, mc, nc);
21292129
break;
21302130
case 2:
21312131
mc = 2;
@@ -2143,7 +2143,7 @@ class tinyBLAS_PPC {
21432143
switch(n_rem) {
21442144
case 1:
21452145
nc = 1;
2146-
gemm_small(m0, m, n0, n, mc, nc);
2146+
gemv(m0, m, n0, n, mc, nc);
21472147
break;
21482148
case 2:
21492149
nc = 2;
@@ -2171,7 +2171,7 @@ class tinyBLAS_PPC {
21712171
case 0x41:
21722172
mc = 4;
21732173
nc = 1;
2174-
gemm_small(m0, m, n0, n, mc, nc);
2174+
gemv(m0, m, n0, n, mc, nc);
21752175
break;
21762176
case 0x34:
21772177
mc = 3;
@@ -2191,7 +2191,7 @@ class tinyBLAS_PPC {
21912191
case 0x31:
21922192
mc = 3;
21932193
nc = 1;
2194-
gemm_small(m0, m, n0, n, mc, nc);
2194+
gemv(m0, m, n0, n, mc, nc);
21952195
break;
21962196
case 0x24:
21972197
mc = 2;
@@ -2211,27 +2211,27 @@ class tinyBLAS_PPC {
22112211
case 0x21:
22122212
mc = 2;
22132213
nc = 1;
2214-
gemm_small(m0, m, n0, n, mc, nc);
2214+
gemv(m0, m, n0, n, mc, nc);
22152215
break;
22162216
case 0x14:
22172217
mc = 1;
22182218
nc = 4;
2219-
gemm_small(m0, m, n0, n, mc, nc);
2219+
gemv(m0, m, n0, n, mc, nc);
22202220
break;
22212221
case 0x13:
22222222
mc = 1;
22232223
nc = 3;
2224-
gemm_small(m0, m, n0, n, mc, nc);
2224+
gemv(m0, m, n0, n, mc, nc);
22252225
break;
22262226
case 0x12:
22272227
mc = 1;
22282228
nc = 2;
2229-
gemm_small(m0, m, n0, n, mc, nc);
2229+
gemv(m0, m, n0, n, mc, nc);
22302230
break;
22312231
case 0x11:
22322232
mc = 1;
22332233
nc = 1;
2234-
gemm_small(m0, m, n0, n, mc, nc);
2234+
gemv(m0, m, n0, n, mc, nc);
22352235
break;
22362236
default:
22372237
return;
@@ -2285,6 +2285,53 @@ class tinyBLAS_PPC {
22852285
}
22862286
}
22872287

2288+
void gemv(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2289+
//printf("In gemv, RM = %d, RN = %d \n", RM, RN);
2290+
int64_t ytiles = (m - m0) / RM;
2291+
int64_t xtiles = (n - n0) / RN;
2292+
int64_t tiles = xtiles * ytiles;
2293+
int64_t duty = (tiles + nth - 1) / nth;
2294+
int64_t start = duty * ith;
2295+
int64_t end = start + duty;
2296+
if (end > tiles)
2297+
end = tiles;
2298+
for (int64_t job = start; job < end; ++job) {
2299+
int64_t ii = m0 + job / xtiles * RM;
2300+
int64_t jj = n0 + job % xtiles * RN;
2301+
vec_t vec_C[4];
2302+
acc_t acc_0;
2303+
__builtin_mma_xxsetaccz(&acc_0);
2304+
vec_t vec_A[4], vec_B[4];
2305+
for (int l=0; l<k; l+=4) {
2306+
if (RM == 1) {
2307+
TA* a = const_cast<TA*>(A+(ii)*lda+l);
2308+
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
2309+
vec_A[0] = (vec_t)vec_xl(0,a);
2310+
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
2311+
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
2312+
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
2313+
} else if (RN == 1) {
2314+
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
2315+
TB* b = const_cast<TB*>(B+(jj)*ldb+l);
2316+
vec_B[0] = (vec_t)vec_xl(0,b);
2317+
vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
2318+
vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
2319+
vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
2320+
}
2321+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2322+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2323+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
2324+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
2325+
}
2326+
__builtin_mma_disassemble_acc(vec_C, &acc_0);
2327+
for (int I = 0; I < RM; I++) {
2328+
for (int J = 0; J < RN; J++) {
2329+
*((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
2330+
}
2331+
}
2332+
}
2333+
}
2334+
22882335
template <int RM, int RN>
22892336
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
22902337
int64_t ytiles = (m - m0) / RM;
@@ -2370,9 +2417,11 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
23702417
assert(params->ith < params->nth);
23712418

23722419
// only enable sgemm for prompt processing
2420+
/*#if !defined(__MMA__)
23732421
if (n < 2)
23742422
return false;
2375-
2423+
#endif
2424+
*/
23762425
if (Ctype != GGML_TYPE_F32)
23772426
return false;
23782427

@@ -2401,7 +2450,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
24012450
(const float *)B, ldb,
24022451
(float *)C, ldc};
24032452
return tb.matmul(m, n);
2404-
#elif defined(__MMA__)
2453+
/*#elif defined(__MMA__)
24052454
if (k % 8)
24062455
return false;
24072456
tinyBLAS_PPC<float, float, float> tb{
@@ -2411,6 +2460,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
24112460
params->ith, params->nth};
24122461
tb.matmul(m, n);
24132462
return true;
2463+
*/
24142464
#else
24152465
return false;
24162466
#endif

0 commit comments

Comments
 (0)