@@ -2125,7 +2125,7 @@ class tinyBLAS_PPC {
21252125 switch (m_rem) {
21262126 case 1 :
21272127 mc = 1 ;
2128- gemm_small (m0, m, n0, n, mc, nc);
2128+ gemv (m0, m, n0, n, mc, nc);
21292129 break ;
21302130 case 2 :
21312131 mc = 2 ;
@@ -2143,7 +2143,7 @@ class tinyBLAS_PPC {
21432143 switch (n_rem) {
21442144 case 1 :
21452145 nc = 1 ;
2146- gemm_small (m0, m, n0, n, mc, nc);
2146+ gemv (m0, m, n0, n, mc, nc);
21472147 break ;
21482148 case 2 :
21492149 nc = 2 ;
@@ -2171,7 +2171,7 @@ class tinyBLAS_PPC {
21712171 case 0x41 :
21722172 mc = 4 ;
21732173 nc = 1 ;
2174- gemm_small (m0, m, n0, n, mc, nc);
2174+ gemv (m0, m, n0, n, mc, nc);
21752175 break ;
21762176 case 0x34 :
21772177 mc = 3 ;
@@ -2191,7 +2191,7 @@ class tinyBLAS_PPC {
21912191 case 0x31 :
21922192 mc = 3 ;
21932193 nc = 1 ;
2194- gemm_small (m0, m, n0, n, mc, nc);
2194+ gemv (m0, m, n0, n, mc, nc);
21952195 break ;
21962196 case 0x24 :
21972197 mc = 2 ;
@@ -2211,27 +2211,27 @@ class tinyBLAS_PPC {
22112211 case 0x21 :
22122212 mc = 2 ;
22132213 nc = 1 ;
2214- gemm_small (m0, m, n0, n, mc, nc);
2214+ gemv (m0, m, n0, n, mc, nc);
22152215 break ;
22162216 case 0x14 :
22172217 mc = 1 ;
22182218 nc = 4 ;
2219- gemm_small (m0, m, n0, n, mc, nc);
2219+ gemv (m0, m, n0, n, mc, nc);
22202220 break ;
22212221 case 0x13 :
22222222 mc = 1 ;
22232223 nc = 3 ;
2224- gemm_small (m0, m, n0, n, mc, nc);
2224+ gemv (m0, m, n0, n, mc, nc);
22252225 break ;
22262226 case 0x12 :
22272227 mc = 1 ;
22282228 nc = 2 ;
2229- gemm_small (m0, m, n0, n, mc, nc);
2229+ gemv (m0, m, n0, n, mc, nc);
22302230 break ;
22312231 case 0x11 :
22322232 mc = 1 ;
22332233 nc = 1 ;
2234- gemm_small (m0, m, n0, n, mc, nc);
2234+ gemv (m0, m, n0, n, mc, nc);
22352235 break ;
22362236 default :
22372237 return ;
@@ -2285,6 +2285,53 @@ class tinyBLAS_PPC {
22852285 }
22862286 }
22872287
2288+ void gemv (int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2289+ // printf("In gemv, RM = %d, RN = %d \n", RM, RN);
2290+ int64_t ytiles = (m - m0) / RM;
2291+ int64_t xtiles = (n - n0) / RN;
2292+ int64_t tiles = xtiles * ytiles;
2293+ int64_t duty = (tiles + nth - 1 ) / nth;
2294+ int64_t start = duty * ith;
2295+ int64_t end = start + duty;
2296+ if (end > tiles)
2297+ end = tiles;
2298+ for (int64_t job = start; job < end; ++job) {
2299+ int64_t ii = m0 + job / xtiles * RM;
2300+ int64_t jj = n0 + job % xtiles * RN;
2301+ vec_t vec_C[4 ];
2302+ acc_t acc_0;
2303+ __builtin_mma_xxsetaccz (&acc_0);
2304+ vec_t vec_A[4 ], vec_B[4 ];
2305+ for (int l=0 ; l<k; l+=4 ) {
2306+ if (RM == 1 ) {
2307+ TA* a = const_cast <TA*>(A+(ii)*lda+l);
2308+ packTranspose<vector float >(B+(jj*ldb)+l, ldb, RN, 4 , (TA*)vec_B);
2309+ vec_A[0 ] = (vec_t )vec_xl (0 ,a);
2310+ vec_A[1 ] = (vec_t )vec_splats (*((TA*)&vec_A+1 ));
2311+ vec_A[2 ] = (vec_t )vec_splats (*((TA*)&vec_A+2 ));
2312+ vec_A[3 ] = (vec_t )vec_splats (*((TA*)&vec_A+3 ));
2313+ } else if (RN == 1 ) {
2314+ packTranspose<vector float >(A+(ii*lda)+l, lda, RM, 4 , (TA*)vec_A);
2315+ TB* b = const_cast <TB*>(B+(jj)*ldb+l);
2316+ vec_B[0 ] = (vec_t )vec_xl (0 ,b);
2317+ vec_B[1 ] = (vec_t )vec_splats (*((TB*)&vec_B+1 ));
2318+ vec_B[2 ] = (vec_t )vec_splats (*((TB*)&vec_B+2 ));
2319+ vec_B[3 ] = (vec_t )vec_splats (*((TB*)&vec_B+3 ));
2320+ }
2321+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[0 ], vec_B[0 ]);
2322+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[1 ], vec_B[1 ]);
2323+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[2 ], vec_B[2 ]);
2324+ __builtin_mma_xvf32gerpp (&acc_0, vec_A[3 ], vec_B[3 ]);
2325+ }
2326+ __builtin_mma_disassemble_acc (vec_C, &acc_0);
2327+ for (int I = 0 ; I < RM; I++) {
2328+ for (int J = 0 ; J < RN; J++) {
2329+ *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
2330+ }
2331+ }
2332+ }
2333+ }
2334+
22882335 template <int RM, int RN>
22892336 NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
22902337 int64_t ytiles = (m - m0) / RM;
@@ -2370,9 +2417,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
23702417 assert (params->ith < params->nth );
23712418
23722419 // only enable sgemm for prompt processing
2420+ #if !defined(__MMA__)
23732421 if (n < 2 )
23742422 return false ;
2375-
2423+ # endif
23762424 if (Ctype != GGML_TYPE_F32)
23772425 return false ;
23782426
0 commit comments