diff --git a/llamafile/tinyblas_cpu.h b/llamafile/tinyblas_cpu.h index bf37ade822..da99d1b6a5 100644 --- a/llamafile/tinyblas_cpu.h +++ b/llamafile/tinyblas_cpu.h @@ -550,13 +550,22 @@ class tinyBLAS { D Cv[RN][RM] = {}; for (long l = 0; l < KN * CHUNK * 4; l += KN) -#pragma GCC unroll 100 - for (j = 0; j < RN; ++j) -#pragma GCC unroll 100 + if constexpr (RM<=RN) +# pragma GCC unroll 100 + for (j = 0; j < RN; ++j) +# pragma GCC unroll 100 + for (i = 0; i < RM; ++i) + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, chunk + l)), // + load(INDEX(B, ldb, jj + j, chunk + l)), // + Cv[j][i]); + else +# pragma GCC unroll 100 for (i = 0; i < RM; ++i) - Cv[j][i] = madd(load(INDEX(A, lda, ii + i, chunk + l)), // - load(INDEX(B, ldb, jj + j, chunk + l)), // - Cv[j][i]); +# pragma GCC unroll 100 + for (j = 0; j < RN; ++j) + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, chunk + l)), // + load(INDEX(B, ldb, jj + j, chunk + l)), // + Cv[j][i]); for (rule = bsr(step & -step); --rule;) for (--sp, j = 0; j < RN; ++j) @@ -570,13 +579,23 @@ class tinyBLAS { D Cv[RN][RM] = {}; for (; chunk + KN <= k; chunk += KN) -#pragma GCC unroll 100 - for (j = 0; j < RN; ++j) -#pragma GCC unroll 100 + if constexpr (RM<=RN) +# pragma GCC unroll 100 + for (j = 0; j < RN; ++j) +# pragma GCC unroll 100 + for (i = 0; i < RM; ++i) + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, chunk)), // + load(INDEX(B, ldb, jj + j, chunk)), // + Cv[j][i]); + else +# pragma GCC unroll 100 for (i = 0; i < RM; ++i) - Cv[j][i] = madd(load(INDEX(A, lda, ii + i, chunk)), // - load(INDEX(B, ldb, jj + j, chunk)), // - Cv[j][i]); +# pragma GCC unroll 100 + for (j = 0; j < RN; ++j) + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, chunk)), // + load(INDEX(B, ldb, jj + j, chunk)), // + Cv[j][i]); + while (sp--) for (j = 0; j < RN; ++j) @@ -589,11 +608,18 @@ class tinyBLAS { Cf[j][i] = hsum(Cv[j][i]); for (; chunk < k; ++chunk) - for (j = 0; j < RN; ++j) + if constexpr (RM<=RN) + for (j = 0; j < RN; ++j) + for (i = 0; i < RM; ++i) + Cf[j][i] = fmaf(load(INDEX(A, lda, ii + i, chunk)), // + load(INDEX(B, ldb, jj + j, chunk)), // + Cf[j][i]); + else for (i = 0; i < RM; ++i) - Cf[j][i] = fmaf(load(INDEX(A, lda, ii + i, chunk)), // - load(INDEX(B, ldb, jj + j, chunk)), // - Cf[j][i]); + for (j = 0; j < RN; ++j) + Cf[j][i] = fmaf(load(INDEX(A, lda, ii + i, chunk)), // + load(INDEX(B, ldb, jj + j, chunk)), // + Cf[j][i]); for (j = 0; j < RN; ++j) for (i = 0; i < RM; ++i)