@@ -1435,7 +1435,7 @@ class tinyBLAS_PPC {
14351435 switch (m_rem) {
14361436 case 1 :
14371437 mc = 1 ;
1438- gemm_small (m0, m, n0, n, mc, nc);
1438+ gemv (m0, m, n0, n, mc, nc);
14391439 break ;
14401440 case 2 :
14411441 mc = 2 ;
@@ -1453,7 +1453,7 @@ class tinyBLAS_PPC {
14531453 switch (n_rem) {
14541454 case 1 :
14551455 nc = 1 ;
1456- gemm_small (m0, m, n0, n, mc, nc);
1456+ gemv (m0, m, n0, n, mc, nc);
14571457 break ;
14581458 case 2 :
14591459 nc = 2 ;
@@ -1481,7 +1481,7 @@ class tinyBLAS_PPC {
14811481 case 0x41 :
14821482 mc = 4 ;
14831483 nc = 1 ;
1484- gemm_small (m0, m, n0, n, mc, nc);
1484+ gemv (m0, m, n0, n, mc, nc);
14851485 break ;
14861486 case 0x34 :
14871487 mc = 3 ;
@@ -1501,7 +1501,7 @@ class tinyBLAS_PPC {
15011501 case 0x31 :
15021502 mc = 3 ;
15031503 nc = 1 ;
1504- gemm_small (m0, m, n0, n, mc, nc);
1504+ gemv (m0, m, n0, n, mc, nc);
15051505 break ;
15061506 case 0x24 :
15071507 mc = 2 ;
@@ -1521,27 +1521,27 @@ class tinyBLAS_PPC {
15211521 case 0x21 :
15221522 mc = 2 ;
15231523 nc = 1 ;
1524- gemm_small (m0, m, n0, n, mc, nc);
1524+ gemv (m0, m, n0, n, mc, nc);
15251525 break ;
15261526 case 0x14 :
15271527 mc = 1 ;
15281528 nc = 4 ;
1529- gemm_small (m0, m, n0, n, mc, nc);
1529+ gemv (m0, m, n0, n, mc, nc);
15301530 break ;
15311531 case 0x13 :
15321532 mc = 1 ;
15331533 nc = 3 ;
1534- gemm_small (m0, m, n0, n, mc, nc);
1534+ gemv (m0, m, n0, n, mc, nc);
15351535 break ;
15361536 case 0x12 :
15371537 mc = 1 ;
15381538 nc = 2 ;
1539- gemm_small (m0, m, n0, n, mc, nc);
1539+ gemv (m0, m, n0, n, mc, nc);
15401540 break ;
15411541 case 0x11 :
15421542 mc = 1 ;
15431543 nc = 1 ;
1544- gemm_small (m0, m, n0, n, mc, nc);
1544+ gemv (m0, m, n0, n, mc, nc);
15451545 break ;
15461546 default :
15471547 return ;
@@ -1595,6 +1595,53 @@ class tinyBLAS_PPC {
15951595 }
15961596 }
15971597
// gemv: computes RM x RN output tiles of C over rows [m0, m) and columns
// [n0, n) for the case where one tile dimension is 1 (RM == 1 or RN == 1),
// using a single MMA accumulator per tile.
//
//   m0, m  - row range of the region to process
//   n0, n  - column range of the region to process
//   RM, RN - tile height / width (runtime values, unlike the templated gemm)
//
// Relies on members declared outside this function: A, B, C, lda, ldb, ldc,
// k, ith, nth.
//
// NOTE(review): only the RM == 1 and RN == 1 branches fill vec_A / vec_B.
// If neither holds, the xvf32gerpp calls below read unassigned vectors --
// every call site visible in this change passes a 1 in one of the two, but
// confirm no other caller exists.
// NOTE(review): the k-loop advances by 4 with no tail handling; presumably
// k is a multiple of 4 -- confirm against the caller.
1598+ void gemv (int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
1599+ // printf("In gemv, RM = %d, RN = %d \n", RM, RN);
// Split the tile grid evenly: each of the nth workers gets `duty`
// consecutive tiles (rounded up); worker ith handles [start, end).
1600+ int64_t ytiles = (m - m0) / RM;
1601+ int64_t xtiles = (n - n0) / RN;
1602+ int64_t tiles = xtiles * ytiles;
1603+ int64_t duty = (tiles + nth - 1 ) / nth;
1604+ int64_t start = duty * ith;
1605+ int64_t end = start + duty;
// Clamp the last worker's range to the actual number of tiles.
1606+ if (end > tiles)
1607+ end = tiles;
1608+ for (int64_t job = start; job < end; ++job) {
// Map the linear job index to the tile's top-left (row ii, column jj).
1609+ int64_t ii = m0 + job / xtiles * RM;
1610+ int64_t jj = n0 + job % xtiles * RN;
1611+ vec_t vec_C[4 ];
1612+ acc_t acc_0;
// The accumulator is zeroed once per tile, before the k-loop.
1613+ __builtin_mma_xxsetaccz (&acc_0);
1614+ vec_t vec_A[4 ], vec_B[4 ];
1615+ for (int l=0 ; l<k; l+=4 ) {
1616+ if (RM == 1 ) {
// Single row of A: load the elements at a[0..] into vec_A[0] as-is, then
// broadcast elements 1..3 of that loaded vector into vec_A[1..3].
// B is packed through packTranspose into vec_B.
1617+ TA* a = const_cast <TA*>(A+(ii)*lda+l);
1618+ packTranspose<vector float >(B+(jj*ldb)+l, ldb, RN, 4 , (TA*)vec_B);
1619+ vec_A[0 ] = (vec_t )vec_xl (0 ,a);
1620+ vec_A[1 ] = (vec_t )vec_splats (*((TA*)&vec_A+1 ));
1621+ vec_A[2 ] = (vec_t )vec_splats (*((TA*)&vec_A+2 ));
1622+ vec_A[3 ] = (vec_t )vec_splats (*((TA*)&vec_A+3 ));
1623+ } else if (RN == 1 ) {
// Single column of B: mirror image of the branch above, with the roles of
// A and B swapped.
1624+ packTranspose<vector float >(A+(ii*lda)+l, lda, RM, 4 , (TA*)vec_A);
1625+ TB* b = const_cast <TB*>(B+(jj)*ldb+l);
1626+ vec_B[0 ] = (vec_t )vec_xl (0 ,b);
1627+ vec_B[1 ] = (vec_t )vec_splats (*((TB*)&vec_B+1 ));
1628+ vec_B[2 ] = (vec_t )vec_splats (*((TB*)&vec_B+2 ));
1629+ vec_B[3 ] = (vec_t )vec_splats (*((TB*)&vec_B+3 ));
1630+ }
// Four accumulating updates of acc_0, one per vec_A[i] / vec_B[i] pair.
1631+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[0 ], vec_B[0 ]);
1632+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[1 ], vec_B[1 ]);
1633+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[2 ], vec_B[2 ]);
1634+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[3 ], vec_B[3 ]);
1635+ }
// Unpack the accumulator into vec_C[0..3], then store only the RM x RN
// part: element J of vec_C[I] goes to C[(ii + I) + (jj + J) * ldc].
1636+ __builtin_mma_disassemble_acc (vec_C, &acc_0);
1637+ for (int I = 0 ; I < RM; I++) {
1638+ for (int J = 0 ; J < RN; J++) {
1639+ *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1640+ }
1641+ }
1642+ }
1643+ }
1644+
15981645 template <int RM, int RN>
15991646 NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
16001647 int64_t ytiles = (m - m0) / RM;
@@ -1680,9 +1727,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
16801727 assert (params->ith < params->nth );
16811728
16821729 // only enable sgemm for prompt processing
1730+ #if !defined(__MMA__)
16831731 if (n < 2 )
16841732 return false ;
1685-
1733+ # endif
16861734 if (Ctype != GGML_TYPE_F32)
16871735 return false ;
16881736
0 commit comments