@@ -2125,7 +2125,7 @@ class tinyBLAS_PPC {
21252125            switch (m_rem) {
21262126                case  1 :
21272127                    mc = 1 ;
2128-                     gemm_small (m0, m, n0, n, mc, nc);
2128+                     gemv (m0, m, n0, n, mc, nc);
21292129                    break ;
21302130                case  2 :
21312131                    mc = 2 ;
@@ -2143,7 +2143,7 @@ class tinyBLAS_PPC {
21432143            switch (n_rem) {
21442144                case  1 :
21452145                    nc = 1 ;
2146-                     gemm_small (m0, m, n0, n, mc, nc);
2146+                     gemv (m0, m, n0, n, mc, nc);
21472147                    break ;
21482148                case  2 :
21492149                    nc = 2 ;
@@ -2171,7 +2171,7 @@ class tinyBLAS_PPC {
21712171                case  0x41 :
21722172                    mc = 4 ;
21732173                    nc = 1 ;
2174-                     gemm_small (m0, m, n0, n, mc, nc);
2174+                     gemv (m0, m, n0, n, mc, nc);
21752175                    break ;
21762176                case  0x34 :
21772177                    mc = 3 ;
@@ -2191,7 +2191,7 @@ class tinyBLAS_PPC {
21912191                case  0x31 :
21922192                    mc = 3 ;
21932193                    nc = 1 ;
2194-                     gemm_small (m0, m, n0, n, mc, nc);
2194+                     gemv (m0, m, n0, n, mc, nc);
21952195                    break ;
21962196                case  0x24 :
21972197                    mc = 2 ;
@@ -2211,27 +2211,27 @@ class tinyBLAS_PPC {
22112211                case  0x21 :
22122212                    mc = 2 ;
22132213                    nc = 1 ;
2214-                     gemm_small (m0, m, n0, n, mc, nc);
2214+                     gemv (m0, m, n0, n, mc, nc);
22152215                    break ;
22162216                case  0x14 :
22172217                    mc = 1 ;
22182218                    nc = 4 ;
2219-                     gemm_small (m0, m, n0, n, mc, nc);
2219+                     gemv (m0, m, n0, n, mc, nc);
22202220                    break ;
22212221                case  0x13 :
22222222                    mc = 1 ;
22232223                    nc = 3 ;
2224-                     gemm_small (m0, m, n0, n, mc, nc);
2224+                     gemv (m0, m, n0, n, mc, nc);
22252225                    break ;
22262226                case  0x12 :
22272227                    mc = 1 ;
22282228                    nc = 2 ;
2229-                     gemm_small (m0, m, n0, n, mc, nc);
2229+                     gemv (m0, m, n0, n, mc, nc);
22302230                    break ;
22312231                case  0x11 :
22322232                    mc = 1 ;
22332233                    nc = 1 ;
2234-                     gemm_small (m0, m, n0, n, mc, nc);
2234+                     gemv (m0, m, n0, n, mc, nc);
22352235                    break ;
22362236                default :
22372237                    return ;
@@ -2285,6 +2285,53 @@ class tinyBLAS_PPC {
22852285       }
22862286    }
22872287
2288+      void  gemv (int64_t  m0, int64_t  m, int64_t  n0, int64_t  n, int  RM, int  RN) {
2289+ 	// printf("In gemv, RM = %d, RN = %d \n", RM, RN);     
2290+         int64_t  ytiles = (m - m0) / RM;
2291+         int64_t  xtiles = (n - n0) / RN;
2292+         int64_t  tiles = xtiles * ytiles;
2293+         int64_t  duty = (tiles + nth - 1 ) / nth;
2294+         int64_t  start = duty * ith;
2295+         int64_t  end = start + duty;
2296+         if  (end > tiles)
2297+             end = tiles;
2298+         for  (int64_t  job = start; job < end; ++job) {
2299+             int64_t  ii = m0 + job / xtiles * RM;
2300+             int64_t  jj = n0 + job % xtiles * RN;
2301+             vec_t  vec_C[4 ];
2302+             acc_t  acc_0;
2303+             __builtin_mma_xxsetaccz (&acc_0);
2304+             vec_t  vec_A[4 ], vec_B[4 ];
2305+             for  (int  l=0 ; l<k; l+=4 ) {
2306+                 if  (RM == 1 ) {
2307+                     TA* a = const_cast <TA*>(A+(ii)*lda+l);
2308+                     packTranspose<vector float >(B+(jj*ldb)+l, ldb, RN, 4 , (TA*)vec_B);
2309+                     vec_A[0 ] = (vec_t )vec_xl (0 ,a);
2310+                     vec_A[1 ] = (vec_t )vec_splats (*((TA*)&vec_A+1 ));
2311+                     vec_A[2 ] = (vec_t )vec_splats (*((TA*)&vec_A+2 ));
2312+                     vec_A[3 ] = (vec_t )vec_splats (*((TA*)&vec_A+3 ));
2313+                 } else  if  (RN == 1 ) {
2314+                     packTranspose<vector float >(A+(ii*lda)+l, lda, RM, 4 , (TA*)vec_A);
2315+                     TB* b = const_cast <TB*>(B+(jj)*ldb+l);
2316+                     vec_B[0 ] = (vec_t )vec_xl (0 ,b);
2317+                     vec_B[1 ] = (vec_t )vec_splats (*((TB*)&vec_B+1 ));
2318+                     vec_B[2 ] = (vec_t )vec_splats (*((TB*)&vec_B+2 ));
2319+                     vec_B[3 ] = (vec_t )vec_splats (*((TB*)&vec_B+3 ));
2320+                 }
2321+                 __builtin_mma_xvf32gerpp (&acc_0, vec_A[0 ], vec_B[0 ]);
2322+                 __builtin_mma_xvf32gerpp (&acc_0, vec_A[1 ], vec_B[1 ]);
2323+                 __builtin_mma_xvf32gerpp (&acc_0, vec_A[2 ], vec_B[2 ]);
2324+                 __builtin_mma_xvf32gerpp (&acc_0, vec_A[3 ], vec_B[3 ]);
2325+             }
2326+             __builtin_mma_disassemble_acc (vec_C, &acc_0);
2327+             for  (int  I = 0 ; I < RM; I++) {
2328+                 for  (int  J = 0 ; J < RN; J++) {
2329+                     *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
2330+                 }
2331+             }
2332+        }
2333+     }
2334+     
22882335    template  <int  RM, int  RN>
22892336    NOINLINE void  gemm (int64_t  m0, int64_t  m, int64_t  n0, int64_t  n) {
22902337        int64_t  ytiles = (m - m0) / RM;
@@ -2370,9 +2417,11 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
23702417    assert (params->ith  < params->nth );
23712418
23722419    //  only enable sgemm for prompt processing
2420+ /* #if !defined(__MMA__)
23732421    if (n < 2) 
23742422        return false; 
2375- 
2423+ #endif 
2424+ */ 
23762425    if  (Ctype != GGML_TYPE_F32)
23772426        return  false ;
23782427
@@ -2401,7 +2450,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
24012450            (const  float  *)B, ldb,
24022451            (float  *)C, ldc};
24032452        return  tb.matmul (m, n);
2404- #elif  defined(__MMA__)
2453+ /* #elif defined(__MMA__)
24052454        if (k % 8) 
24062455            return false; 
24072456        tinyBLAS_PPC<float, float, float> tb{ 
@@ -2411,6 +2460,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
24112460            params->ith, params->nth}; 
24122461        tb.matmul(m, n); 
24132462        return true; 
2463+ */ 
24142464#else 
24152465        return  false ;
24162466#endif 
0 commit comments