Skip to content

Commit 7d1d1c1

Browse files
committed
llamafile : ppc64le GEMV forwarding for FP32.
This patch enables usage of MMA when one of the dimensions of the matrix(ie either M or N) is 1. This is useful in case of token generation where N < 2. The concept of 'GEMV Forwarding' is used where when one of the matrix has a single row/column, the elements are broadcasted, instead of using packing routine to prepack the matrix elements. This change results in 5% - 15% improvement in total speed(ie all tokens/total time), across various batch sizes. This is in comparision with the corresponding dot product implementation. The patch is tested with FP32 models of Meta-Lllama-3-8B, Mistral-7B, Llama-2-7B-chat-hf on a IBM POWER10 machine. Signed-off-by: Amrita H S <[email protected]>
1 parent d7cfe1f commit 7d1d1c1

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,13 +2261,20 @@ class tinyBLAS_PPC {
22612261
__builtin_mma_xxsetaccz(&acc_0);
22622262
vec_t vec_A[4], vec_B[4];
22632263
for (int l=0; l<k; l+=4) {
2264-
if (RN >= 4 && RM == 1) {
2264+
if (RM == 1) {
22652265
TA* a = const_cast<TA*>(A+(ii)*lda+l);
2266-
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2266+
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
22672267
vec_A[0] = (vec_t)vec_xl(0,a);
22682268
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
22692269
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
22702270
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
2271+
} else if (RN == 1) {
2272+
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
2273+
TB* b = const_cast<TB*>(B+(jj)*ldb+l);
2274+
vec_B[0] = (vec_t)vec_xl(0,b);
2275+
vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
2276+
vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
2277+
vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
22712278
} else {
22722279
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
22732280
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
@@ -2371,8 +2378,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
23712378
assert(params->ith < params->nth);
23722379

23732380
// only enable sgemm for prompt processing
2381+
#if !defined(__MMA__)
23742382
if (n < 2)
23752383
return false;
2384+
#endif
23762385

23772386
if (Ctype != GGML_TYPE_F32)
23782387
return false;

0 commit comments

Comments
 (0)