Skip to content

Commit d8abafe

Browse files
committed
GEMV implementation
Signed-off-by: Amrita H S <[email protected]>
1 parent 0827b2c commit d8abafe

File tree

1 file changed

+58
-10
lines changed

1 file changed

+58
-10
lines changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,7 @@ class tinyBLAS_PPC {
14351435
switch(m_rem) {
14361436
case 1:
14371437
mc = 1;
1438-
gemm_small(m0, m, n0, n, mc, nc);
1438+
gemv(m0, m, n0, n, mc, nc);
14391439
break;
14401440
case 2:
14411441
mc = 2;
@@ -1453,7 +1453,7 @@ class tinyBLAS_PPC {
14531453
switch(n_rem) {
14541454
case 1:
14551455
nc = 1;
1456-
gemm_small(m0, m, n0, n, mc, nc);
1456+
gemv(m0, m, n0, n, mc, nc);
14571457
break;
14581458
case 2:
14591459
nc = 2;
@@ -1481,7 +1481,7 @@ class tinyBLAS_PPC {
14811481
case 0x41:
14821482
mc = 4;
14831483
nc = 1;
1484-
gemm_small(m0, m, n0, n, mc, nc);
1484+
gemv(m0, m, n0, n, mc, nc);
14851485
break;
14861486
case 0x34:
14871487
mc = 3;
@@ -1501,7 +1501,7 @@ class tinyBLAS_PPC {
15011501
case 0x31:
15021502
mc = 3;
15031503
nc = 1;
1504-
gemm_small(m0, m, n0, n, mc, nc);
1504+
gemv(m0, m, n0, n, mc, nc);
15051505
break;
15061506
case 0x24:
15071507
mc = 2;
@@ -1521,27 +1521,27 @@ class tinyBLAS_PPC {
15211521
case 0x21:
15221522
mc = 2;
15231523
nc = 1;
1524-
gemm_small(m0, m, n0, n, mc, nc);
1524+
gemv(m0, m, n0, n, mc, nc);
15251525
break;
15261526
case 0x14:
15271527
mc = 1;
15281528
nc = 4;
1529-
gemm_small(m0, m, n0, n, mc, nc);
1529+
gemv(m0, m, n0, n, mc, nc);
15301530
break;
15311531
case 0x13:
15321532
mc = 1;
15331533
nc = 3;
1534-
gemm_small(m0, m, n0, n, mc, nc);
1534+
gemv(m0, m, n0, n, mc, nc);
15351535
break;
15361536
case 0x12:
15371537
mc = 1;
15381538
nc = 2;
1539-
gemm_small(m0, m, n0, n, mc, nc);
1539+
gemv(m0, m, n0, n, mc, nc);
15401540
break;
15411541
case 0x11:
15421542
mc = 1;
15431543
nc = 1;
1544-
gemm_small(m0, m, n0, n, mc, nc);
1544+
gemv(m0, m, n0, n, mc, nc);
15451545
break;
15461546
default:
15471547
return;
@@ -1595,6 +1595,53 @@ class tinyBLAS_PPC {
15951595
}
15961596
}
15971597

1598+
void gemv(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
1599+
//printf("In gemv, RM = %d, RN = %d \n", RM, RN);
1600+
int64_t ytiles = (m - m0) / RM;
1601+
int64_t xtiles = (n - n0) / RN;
1602+
int64_t tiles = xtiles * ytiles;
1603+
int64_t duty = (tiles + nth - 1) / nth;
1604+
int64_t start = duty * ith;
1605+
int64_t end = start + duty;
1606+
if (end > tiles)
1607+
end = tiles;
1608+
for (int64_t job = start; job < end; ++job) {
1609+
int64_t ii = m0 + job / xtiles * RM;
1610+
int64_t jj = n0 + job % xtiles * RN;
1611+
vec_t vec_C[4];
1612+
acc_t acc_0;
1613+
__builtin_mma_xxsetaccz(&acc_0);
1614+
vec_t vec_A[4], vec_B[4];
1615+
for (int l=0; l<k; l+=4) {
1616+
if (RM == 1) {
1617+
TA* a = const_cast<TA*>(A+(ii)*lda+l);
1618+
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
1619+
vec_A[0] = (vec_t)vec_xl(0,a);
1620+
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
1621+
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
1622+
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
1623+
} else if (RN == 1) {
1624+
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
1625+
TB* b = const_cast<TB*>(B+(jj)*ldb+l);
1626+
vec_B[0] = (vec_t)vec_xl(0,b);
1627+
vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
1628+
vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
1629+
vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
1630+
}
1631+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1632+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1633+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
1634+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
1635+
}
1636+
__builtin_mma_disassemble_acc(vec_C, &acc_0);
1637+
for (int I = 0; I < RM; I++) {
1638+
for (int J = 0; J < RN; J++) {
1639+
*((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1640+
}
1641+
}
1642+
}
1643+
}
1644+
15981645
template <int RM, int RN>
15991646
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
16001647
int64_t ytiles = (m - m0) / RM;
@@ -1680,9 +1727,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
16801727
assert(params->ith < params->nth);
16811728

16821729
// only enable sgemm for prompt processing (n >= 2); builds with __MMA__ defined skip this check
1730+
#if !defined(__MMA__)
16831731
if (n < 2)
16841732
return false;
1685-
1733+
#endif
16861734
if (Ctype != GGML_TYPE_F32)
16871735
return false;
16881736

0 commit comments

Comments
 (0)