
Commit 878fc0d

more performance with llamafile tinyblas
- change dispatch strategy (thanks: ikawrakow/ik_llama.cpp#71) - more cache friendly

1 parent 8fa1702 commit 878fc0d

File tree

2 files changed (+177, -182 lines)


llamafile/tinyblas_cpu.h

Lines changed: 148 additions & 152 deletions
@@ -49,7 +49,7 @@
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wignored-attributes"
 
-#define CHUNK 8
+#define CHUNK 16
 #define ROW_ALIGN 64
 #define MATRIX_ALIGN 4096
 #define MAX_ALIGN 4096
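CHUNK sets how many vector steps are folded into one accumulation chunk in gemm_bloc further down; it also sizes the pairwise-summation stack, D stack[bsr(k / CHUNK + 1) + 1][RN][RM]. A quick standalone check of that relationship (mine, not part of the commit), assuming bsr() returns the index of the highest set bit as its name suggests:

```cpp
#include <cstdio>

// Stand-in for the bsr() helper in tinyblas_cpu.h: index of the highest set bit.
static int bsr(long x) { return 63 - __builtin_clzl((unsigned long)x); }

int main() {
    const long k = 4096;                     // hypothetical reduction length
    for (long chunk : {8L, 16L})             // old vs. new CHUNK
        std::printf("CHUNK=%-2ld -> stack depth %d\n", chunk, bsr(k / chunk + 1) + 1);
    // CHUNK=8 -> 10, CHUNK=16 -> 9: the larger chunk keeps one fewer level of
    // partial sums per tile.
    return 0;
}
```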
@@ -416,6 +416,13 @@ inline void store(ggml_bf16_t *p, float f) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
+template <int M>
+static long BLOCK_SIZE(long m) {
+    if (m % M == 0) return M;
+    const long NB_BLOC_M = (m + M - 1) / M;
+    return (m / NB_BLOC_M) + 1;
+}
+
 template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
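To see what the added BLOCK_SIZE helper does when the column count does not divide evenly, here is a small standalone check (mine, not from the commit). It shows the near-uniform block width it picks instead of leaving one thin remainder block:

```cpp
#include <cstdio>

// Copy of the BLOCK_SIZE helper added above.
template <int M>
static long BLOCK_SIZE(long m) {
    if (m % M == 0) return M;
    const long NB_BLOC_M = (m + M - 1) / M;  // number of blocks, rounded up
    return (m / NB_BLOC_M) + 1;              // near-uniform block width
}

int main() {
    for (long n : {12L, 13L, 16L, 100L})
        std::printf("BLOCK_SIZE<6>(%ld) = %ld\n", n, BLOCK_SIZE<6>(n));
    // 12 -> 6 (exact), 13 -> 5, 16 -> 6, 100 -> 6; gemm<RM, RN, BM> then splits
    // n into tiles of this width or one less (see jj_RN in the next hunk).
    return 0;
}
```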
@@ -424,180 +431,169 @@ class tinyBLAS {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(long m, long n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    NOINLINE void mnpack(long m0, long m, long n0, long n) {
-        long mc, nc, mp, np;
-
+    bool matmul(long m, long n) {
 #if VECTOR_REGISTERS == 32
-        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
-        case 0x55:
-            mc = 5;
-            nc = 5;
-            gemm<5, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-        case 0x53:
-        case 0x52:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-        case 0x42:
-        case 0x35:
-        case 0x34:
-        case 0x33:
-        case 0x32:
-        case 0x25:
-        case 0x24:
-        case 0x23:
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x51:
-        case 0x41:
-        case 0x31:
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x15:
-        case 0x14:
-        case 0x13:
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
+        if (m % 8 == 0 && n < 4) {
+            mnpack<8, 3, 1>(m, n, n);
+            return true;
         }
-#endif
-
-#if VECTOR_REGISTERS == 16
-        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x42:
-        case 0x33:
-        case 0x32:
-        case 0x23:
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x41:
-        case 0x31:
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
+        if (m % 16 == 0) {
+            const long SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 4>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 8 == 0) {
+            const long SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 2>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const long SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 1>(m, n, SIZE_N);
+            return true;
+        }
+#else // VECTOR_REGISTERS == 16
+        if (m % 4 == 0 && n < 3) {
+            mnpack<4, 2, 1>(m, n, n);
+            return true;
+        }
+        if (m % 16 == 0) {
+            const long SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 8 == 0) {
+            const long SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 2>(m, n, SIZE_N);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const long SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 1>(m, n, SIZE_N);
+            return true;
         }
 #endif
+        return false;
+    }
 
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
+  private:
+    template <int RM, int RN, int BM>
+    inline void mnpack(long m, long n, long SIZE_N) {
+        if (SIZE_N == RN) {
+            return gemm<RM, RN, BM>(m, n);
+        }
+        if constexpr (RN > 1) {
+            return mnpack<RM, RN - 1, BM>(m, n, SIZE_N);
+        //} else {
+        //    GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+        //    GGML_ASSERT(false); // we have miss something.
+        }
     }
 
     template <int RM, int RN>
-    NOINLINE void gemm(long m0, long m, long n0, long n) {
-        D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
-        long ytiles = RM > 1 ? (m - m0) / RM : 1;
-        long xtiles = RN > 1 ? (n - n0) / RN : 1;
-        long tiles = xtiles * ytiles;
-        long duty = (tiles + nth - 1) / nth;
-        long start = duty * ith;
-        long end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (long job = start; job < end; ++job) {
-            long ii = m0 + job / xtiles * RM;
-            long jj = n0 + job % xtiles * RN;
-
-            size_t chunk, sp = 0;
-            int i, j, rule, step = 2;
-            for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {
-
-                D Cv[RN][RM] = {};
-                for (long l = 0; l < KN * CHUNK * 4; l += KN)
+    inline void gemm_bloc(long ii, long jj, long l, D Cv[RN][RM]) {
+        // help compiler for op order.
+        if constexpr (RM <= RN) {
+            V Av[RM];
 #pragma GCC unroll 100
-                    for (j = 0; j < RN; ++j)
+            for (int64_t i = 0; i < RM; ++i) {
+                Av[i] = load<V>(A + lda * (ii + i) + l);
+            }
 #pragma GCC unroll 100
-                        for (i = 0; i < RM; ++i)
-                            Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk + l)), //
-                                            load<V>(INDEX(B, ldb, jj + j, chunk + l)), //
-                                            Cv[j][i]);
-
-                for (rule = bsr(step & -step); --rule;)
-                    for (--sp, j = 0; j < RN; ++j)
-                        for (i = 0; i < RM; ++i)
-                            Cv[j][i] += stack[sp][j][i];
-
-                for (j = 0; j < RN; ++j)
-                    for (i = 0; i < RM; ++i)
-                        stack[sp][j][i] = Cv[j][i];
+            for (int64_t j = 0; j < RN; ++j) {
+                V Bv = load<V>(B + ldb * (jj + j) + l);
+#pragma GCC unroll 100
+                for (int64_t i = 0; i < RM; ++i) {
+                    Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
+                }
             }
-
-            D Cv[RN][RM] = {};
-            for (; chunk + KN <= k; chunk += KN)
+        } else {
+            V Bv[RN];
 #pragma GCC unroll 100
-                for (j = 0; j < RN; ++j)
+            for (int64_t j = 0; j < RN; ++j) {
+                Bv[j] = load<V>(B + ldb * (jj + j) + l);
+            }
 #pragma GCC unroll 100
-                    for (i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk)), //
-                                        load<V>(INDEX(B, ldb, jj + j, chunk)), //
-                                        Cv[j][i]);
+            for (int64_t i = 0; i < RM; ++i) {
+                V Av = load<V>(A + lda * (ii + i) + l);
+#pragma GCC unroll 100
+                for (int64_t j = 0; j < RN; ++j) {
+                    Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
+                }
+            }
+        }
+    }
+
+    template <int RM, int RN>
+    inline void gemm_bloc(long ii, long jj) {
+        D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
+        long chunk, sp = 0;
+        int i, j, rule, step = 2;
+        for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {
 
-            while (sp--)
-                for (j = 0; j < RN; ++j)
+            D Cv[RN][RM] = {};
+            for (long l = 0; l < KN * CHUNK * 4; l += KN)
+                gemm_bloc<RM, RN>(ii, jj, chunk + l, Cv);
+
+            for (rule = bsr(step & -step); --rule;)
+                for (--sp, j = 0; j < RN; ++j)
                     for (i = 0; i < RM; ++i)
                         Cv[j][i] += stack[sp][j][i];
 
-            float Cf[RN][RM];
             for (j = 0; j < RN; ++j)
                 for (i = 0; i < RM; ++i)
-                    Cf[j][i] = hsum(Cv[j][i]);
+                    stack[sp][j][i] = Cv[j][i];
+        }
 
-            for (; chunk < k; ++chunk)
-                for (j = 0; j < RN; ++j)
-                    for (i = 0; i < RM; ++i)
-                        Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
-                                        load<float>(INDEX(B, ldb, jj + j, chunk)), //
-                                        Cf[j][i]);
+        D Cv[RN][RM] = {};
+        for (; chunk + KN <= k; chunk += KN)
+            gemm_bloc<RM, RN>(ii, jj, chunk, Cv);
+
+        while (sp--)
+            for (j = 0; j < RN; ++j)
+                for (i = 0; i < RM; ++i)
+                    Cv[j][i] += stack[sp][j][i];
+
+        float Cf[RN][RM];
+        for (j = 0; j < RN; ++j)
+            for (i = 0; i < RM; ++i)
+                Cf[j][i] = hsum(Cv[j][i]);
 
+        for (; chunk < k; ++chunk)
             for (j = 0; j < RN; ++j)
                 for (i = 0; i < RM; ++i)
-                    store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
+                    Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
+                                    load<float>(INDEX(B, ldb, jj + j, chunk)), //
+                                    Cf[j][i]);
+
+        for (j = 0; j < RN; ++j)
+            for (i = 0; i < RM; ++i)
+                store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
+    }
+
+    template <int RM, int RN, int BM>
+    NOINLINE void gemm(long m, long n) {
+        // GGML_ASSERT(m % (RM * BM) == 0);
+        const long ytiles = m / (RM * BM);
+        const long xtiles = (n + RN - 1) / RN;
+        const long jj_RN = (xtiles - (xtiles * RN - n));
+
+        long tiles = xtiles * ytiles;
+        long duty = (tiles + nth - 1) / nth;
+        long start = duty * ith;
+        long end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            const int64_t ii = job / xtiles;
+            const int64_t jj = job % xtiles;
+            for (int64_t bi = 0; bi < BM; ++bi) {
+                if (jj < jj_RN) {
+                    gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
+                } else if constexpr (RN > 1) {
+                    gemm_bloc<RM, RN - 1>((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1));
+                }
+            }
         }
     }
 
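Putting the pieces together, my reading of the new dispatch (a sketch, not code from the commit): matmul picks an <RM, RN, BM> configuration from m and VECTOR_REGISTERS, BLOCK_SIZE picks the column-block width, and each gemm job then runs BM consecutive RM-row blocks against one column block, so the same few columns of B stay hot in cache across several row blocks. The job-to-tile mapping for a small hypothetical case:

```cpp
#include <cstdio>

int main() {
    // One of the new VECTOR_REGISTERS == 32 configurations (picked when m % 16 == 0).
    const long RM = 4, RN = 6, BM = 4;
    const long m = 64, n = 16;                       // hypothetical problem sizes
    const long ytiles = m / (RM * BM);               // row blocks of RM*BM rows
    const long xtiles = (n + RN - 1) / RN;           // column blocks, rounded up
    const long jj_RN = xtiles - (xtiles * RN - n);   // blocks that keep the full RN width

    for (long job = 0; job < xtiles * ytiles; ++job) {
        const long ii = job / xtiles, jj = job % xtiles;
        const long rn = jj < jj_RN ? RN : RN - 1;    // 6-wide or 5-wide column block
        const long j0 = jj < jj_RN ? jj * RN : jj_RN * RN + (jj - jj_RN) * (RN - 1);
        std::printf("job %2ld: rows [%2ld..%2ld)  cols [%2ld..%2ld)\n",
                    job, ii * BM * RM, (ii + 1) * BM * RM, j0, j0 + rn);
    }
    return 0;
}
```

Compared with the old recursive mnpack, which carved off remainders with progressively smaller kernels, every job here uses the same RM and at most two adjacent column widths (RN and RN-1), which appears to be what the "more cache friendly" note in the commit message refers to.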