diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c22a662876c4a..7197dbde10aef 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2125,7 +2125,7 @@ class tinyBLAS_PPC { switch(m_rem) { case 1: mc = 1; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 2: mc = 2; @@ -2143,7 +2143,7 @@ class tinyBLAS_PPC { switch(n_rem) { case 1: nc = 1; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 2: nc = 2; @@ -2171,7 +2171,7 @@ class tinyBLAS_PPC { case 0x41: mc = 4; nc = 1; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 0x34: mc = 3; @@ -2191,7 +2191,7 @@ class tinyBLAS_PPC { case 0x31: mc = 3; nc = 1; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 0x24: mc = 2; @@ -2211,27 +2211,27 @@ class tinyBLAS_PPC { case 0x21: mc = 2; nc = 1; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 0x14: mc = 1; nc = 4; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 0x13: mc = 1; nc = 3; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 0x12: mc = 1; nc = 2; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; case 0x11: mc = 1; nc = 1; - gemm_small(m0, m, n0, n, mc, nc); + gemv(m0, m, n0, n, mc, nc); break; default: return; @@ -2285,6 +2285,53 @@ class tinyBLAS_PPC { } } + void gemv(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + //printf("In gemv, RM = %d, RN = %d \n", RM, RN); + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + vec_t vec_C[4]; + acc_t acc_0; + __builtin_mma_xxsetaccz(&acc_0); + vec_t vec_A[4], vec_B[4]; + for (int l=0; l(A+(ii)*lda+l); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + vec_A[0] = (vec_t)vec_xl(0,a); + vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + } else if (RN == 1) { + packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); + TB* b = const_cast(B+(jj)*ldb+l); + vec_B[0] = (vec_t)vec_xl(0,b); + vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); + } + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); + } + __builtin_mma_disassemble_acc(vec_C, &acc_0); + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J); + } + } + } + } + template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; @@ -2370,9 +2417,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 assert(params->ith < params->nth); // only enable sgemm for prompt processing +#if !defined(__MMA__) if (n < 2) return false; - +#endif if (Ctype != GGML_TYPE_F32) return false;