Skip to content

Commit e013b4d

Browse files
committed
llamafile : ppc64le GEMV forwarding for FP32.
This patch enables usage of MMA when one of the dimensions of the matrix (i.e. either M or N) is 1. This is useful in the case of token generation, where N < 2. The concept of 'GEMV Forwarding' is used: when one of the matrices has a single row/column, the elements are broadcast instead of using the packing routine to prepack the matrix elements. This change results in a 5% - 15% improvement in total speed (i.e. all tokens/total time) across various batch sizes, in comparison with the corresponding dot-product implementation. The patch was tested with FP32 models of Meta-Llama-3-8B, Mistral-7B, and Llama-2-7B-chat-hf on an IBM POWER10 machine. Signed-off-by: Amrita H S <[email protected]>
1 parent d7cfe1f commit e013b4d

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,13 +2261,25 @@ class tinyBLAS_PPC {
22612261
__builtin_mma_xxsetaccz(&acc_0);
22622262
vec_t vec_A[4], vec_B[4];
22632263
for (int l=0; l<k; l+=4) {
2264-
if (RN >= 4 && RM == 1) {
2264+
/* 'GEMV Forwarding' concept is used in first two conditional loops.
2265+
* when one of the matrix has a single row/column, the elements are
2266+
* broadcasted, instead of using packing routine to prepack the
2267+
* matrix elements.
2268+
*/
2269+
if (RM == 1) {
22652270
TA* a = const_cast<TA*>(A+(ii)*lda+l);
2266-
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2271+
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
22672272
vec_A[0] = (vec_t)vec_xl(0,a);
22682273
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
22692274
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
22702275
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
2276+
} else if (RN == 1) {
2277+
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
2278+
TB* b = const_cast<TB*>(B+(jj)*ldb+l);
2279+
vec_B[0] = (vec_t)vec_xl(0,b);
2280+
vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
2281+
vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
2282+
vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
22712283
} else {
22722284
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
22732285
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
@@ -2371,8 +2383,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
23712383
assert(params->ith < params->nth);
23722384

23732385
// only enable sgemm for prompt processing
2386+
#if !defined(__MMA__)
23742387
if (n < 2)
23752388
return false;
2389+
#endif
23762390

23772391
if (Ctype != GGML_TYPE_F32)
23782392
return false;

0 commit comments

Comments
 (0)