Skip to content

Commit b542e64

Browse files
committed
more performance with llamafile tinyblas
- change dispatch strategy (thanks: ikawrakow/ik_llama.cpp#71 ) - more cache friendly
1 parent 8fa1702 commit b542e64

File tree

2 files changed

+176
-182
lines changed

2 files changed

+176
-182
lines changed

llamafile/tinyblas_cpu.h

Lines changed: 147 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
#pragma GCC diagnostic ignored "-Wpedantic"
5050
#pragma GCC diagnostic ignored "-Wignored-attributes"
5151

52-
#define CHUNK 8
52+
#define CHUNK 16
5353
#define ROW_ALIGN 64
5454
#define MATRIX_ALIGN 4096
5555
#define MAX_ALIGN 4096
@@ -416,6 +416,12 @@ inline void store(ggml_bf16_t *p, float f) {
416416
////////////////////////////////////////////////////////////////////////////////////////////////////
417417
// FLOATING POINT MATRIX MULTIPLICATION
418418

419+
template <int M>
420+
static long BLOCK_SIZE(long m) {
421+
const long NB_BLOC_M = (m + M - 1) / M;
422+
return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
423+
}
424+
419425
template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
420426
class tinyBLAS {
421427
public:
@@ -424,180 +430,169 @@ class tinyBLAS {
424430
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
425431
}
426432

427-
void matmul(long m, long n) {
428-
mnpack(0, m, 0, n);
429-
}
430-
431-
private:
432-
NOINLINE void mnpack(long m0, long m, long n0, long n) {
433-
long mc, nc, mp, np;
434-
433+
bool matmul(long m, long n) {
435434
#if VECTOR_REGISTERS == 32
436-
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
437-
case 0x55:
438-
mc = 5;
439-
nc = 5;
440-
gemm<5, 5>(m0, m, n0, n);
441-
break;
442-
case 0x54:
443-
case 0x53:
444-
case 0x52:
445-
case 0x45:
446-
case 0x44:
447-
case 0x43:
448-
case 0x42:
449-
case 0x35:
450-
case 0x34:
451-
case 0x33:
452-
case 0x32:
453-
case 0x25:
454-
case 0x24:
455-
case 0x23:
456-
case 0x22:
457-
mc = 2;
458-
nc = 2;
459-
gemm<2, 2>(m0, m, n0, n);
460-
break;
461-
case 0x51:
462-
case 0x41:
463-
case 0x31:
464-
case 0x21:
465-
mc = 2;
466-
nc = 1;
467-
gemm<2, 1>(m0, m, n0, n);
468-
break;
469-
case 0x15:
470-
case 0x14:
471-
case 0x13:
472-
case 0x12:
473-
mc = 1;
474-
nc = 2;
475-
gemm<1, 2>(m0, m, n0, n);
476-
break;
477-
case 0x11:
478-
mc = 1;
479-
nc = 1;
480-
gemm<1, 1>(m0, m, n0, n);
481-
break;
482-
default:
483-
return;
435+
if (m % 8 == 0 && n < 4) {
436+
mnpack<8, 3, 1>(m, n, n);
437+
return true;
484438
}
485-
#endif
486-
487-
#if VECTOR_REGISTERS == 16
488-
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {
489-
case 0x43:
490-
mc = 4;
491-
nc = 3;
492-
gemm<4, 3>(m0, m, n0, n);
493-
break;
494-
case 0x42:
495-
case 0x33:
496-
case 0x32:
497-
case 0x23:
498-
case 0x22:
499-
mc = 2;
500-
nc = 2;
501-
gemm<2, 2>(m0, m, n0, n);
502-
break;
503-
case 0x41:
504-
case 0x31:
505-
case 0x21:
506-
mc = 2;
507-
nc = 1;
508-
gemm<2, 1>(m0, m, n0, n);
509-
break;
510-
case 0x13:
511-
case 0x12:
512-
mc = 1;
513-
nc = 2;
514-
gemm<1, 2>(m0, m, n0, n);
515-
break;
516-
case 0x11:
517-
mc = 1;
518-
nc = 1;
519-
gemm<1, 1>(m0, m, n0, n);
520-
break;
521-
default:
522-
return;
439+
if (m % 16 == 0) {
440+
const long SIZE_N = BLOCK_SIZE<6>(n);
441+
mnpack<4, 6, 4>(m, n, SIZE_N);
442+
return true;
443+
}
444+
if (m % 8 == 0) {
445+
const long SIZE_N = BLOCK_SIZE<6>(n);
446+
mnpack<4, 6, 2>(m, n, SIZE_N);
447+
return true;
448+
}
449+
if (m % 4 == 0) {
450+
const long SIZE_N = BLOCK_SIZE<6>(n);
451+
mnpack<4, 6, 1>(m, n, SIZE_N);
452+
return true;
453+
}
454+
#else // VECTOR_REGISTERS == 16
455+
if (m % 4 == 0 && n < 3) {
456+
mnpack<4, 2, 1>(m, n, n);
457+
return true;
458+
}
459+
if (m % 16 == 0) {
460+
const long SIZE_N = BLOCK_SIZE<3>(n);
461+
mnpack<4, 3, 4>(m, n, SIZE_N);
462+
return true;
463+
}
464+
if (m % 8 == 0) {
465+
const long SIZE_N = BLOCK_SIZE<3>(n);
466+
mnpack<4, 3, 2>(m, n, SIZE_N);
467+
return true;
468+
}
469+
if (m % 4 == 0) {
470+
const long SIZE_N = BLOCK_SIZE<3>(n);
471+
mnpack<4, 3, 1>(m, n, SIZE_N);
472+
return true;
523473
}
524474
#endif
475+
return false;
476+
}
525477

526-
mp = m0 + (m - m0) / mc * mc;
527-
np = n0 + (n - n0) / nc * nc;
528-
mnpack(mp, m, n0, np);
529-
mnpack(m0, m, np, n);
478+
private:
479+
template <int RM, int RN, int BM>
480+
inline void mnpack(long m, long n, long SIZE_N) {
481+
if (SIZE_N == RN) {
482+
return gemm<RM, RN, BM>(m, n);
483+
}
484+
if constexpr (RN > 1) {
485+
return mnpack<RM, RN-1, BM>(m, n, SIZE_N);
486+
//} else {
487+
// GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
488+
// GGML_ASSERT(false); // we have miss something.
489+
}
530490
}
531491

532492
template <int RM, int RN>
533-
NOINLINE void gemm(long m0, long m, long n0, long n) {
534-
D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
535-
long ytiles = RM > 1 ? (m - m0) / RM : 1;
536-
long xtiles = RN > 1 ? (n - n0) / RN : 1;
537-
long tiles = xtiles * ytiles;
538-
long duty = (tiles + nth - 1) / nth;
539-
long start = duty * ith;
540-
long end = start + duty;
541-
if (end > tiles)
542-
end = tiles;
543-
for (long job = start; job < end; ++job) {
544-
long ii = m0 + job / xtiles * RM;
545-
long jj = n0 + job % xtiles * RN;
546-
547-
size_t chunk, sp = 0;
548-
int i, j, rule, step = 2;
549-
for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {
550-
551-
D Cv[RN][RM] = {};
552-
for (long l = 0; l < KN * CHUNK * 4; l += KN)
493+
inline void gemm_bloc(long ii, long jj, long l, D Cv[RN][RM]) {
494+
// help compiler for op order.
495+
if constexpr (RM <= RN) {
496+
V Av[RM];
553497
#pragma GCC unroll 100
554-
for (j = 0; j < RN; ++j)
498+
for (int64_t i = 0; i < RM; ++i) {
499+
Av[i] = load<V>(A + lda * (ii + i) + l);
500+
}
555501
#pragma GCC unroll 100
556-
for (i = 0; i < RM; ++i)
557-
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk + l)), //
558-
load<V>(INDEX(B, ldb, jj + j, chunk + l)), //
559-
Cv[j][i]);
560-
561-
for (rule = bsr(step & -step); --rule;)
562-
for (--sp, j = 0; j < RN; ++j)
563-
for (i = 0; i < RM; ++i)
564-
Cv[j][i] += stack[sp][j][i];
565-
566-
for (j = 0; j < RN; ++j)
567-
for (i = 0; i < RM; ++i)
568-
stack[sp][j][i] = Cv[j][i];
502+
for (int64_t j = 0; j < RN; ++j) {
503+
V Bv = load<V>(B + ldb * (jj + j) + l);
504+
#pragma GCC unroll 100
505+
for (int64_t i = 0; i < RM; ++i) {
506+
Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
507+
}
569508
}
570-
571-
D Cv[RN][RM] = {};
572-
for (; chunk + KN <= k; chunk += KN)
509+
} else {
510+
V Bv[RN];
573511
#pragma GCC unroll 100
574-
for (j = 0; j < RN; ++j)
512+
for (int64_t j = 0; j < RN; ++j) {
513+
Bv[j] = load<V>(B + ldb * (jj + j) + l);
514+
}
575515
#pragma GCC unroll 100
576-
for (i = 0; i < RM; ++i)
577-
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk)), //
578-
load<V>(INDEX(B, ldb, jj + j, chunk)), //
579-
Cv[j][i]);
516+
for (int64_t i = 0; i < RM; ++i) {
517+
V Av = load<V>(A + lda * (ii + i) + l);
518+
#pragma GCC unroll 100
519+
for (int64_t j = 0; j < RN; ++j) {
520+
Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
521+
}
522+
}
523+
}
524+
}
525+
526+
template <int RM, int RN>
527+
inline void gemm_bloc(long ii, long jj) {
528+
D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
529+
long chunk, sp = 0;
530+
int i, j, rule, step = 2;
531+
for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {
580532

581-
while (sp--)
582-
for (j = 0; j < RN; ++j)
533+
D Cv[RN][RM] = {};
534+
for (long l = 0; l < KN * CHUNK * 4; l += KN)
535+
gemm_bloc<RM, RN>(ii, jj, chunk + l, Cv);
536+
537+
for (rule = bsr(step & -step); --rule;)
538+
for (--sp, j = 0; j < RN; ++j)
583539
for (i = 0; i < RM; ++i)
584540
Cv[j][i] += stack[sp][j][i];
585541

586-
float Cf[RN][RM];
587542
for (j = 0; j < RN; ++j)
588543
for (i = 0; i < RM; ++i)
589-
Cf[j][i] = hsum(Cv[j][i]);
544+
stack[sp][j][i] = Cv[j][i];
545+
}
590546

591-
for (; chunk < k; ++chunk)
592-
for (j = 0; j < RN; ++j)
593-
for (i = 0; i < RM; ++i)
594-
Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
595-
load<float>(INDEX(B, ldb, jj + j, chunk)), //
596-
Cf[j][i]);
547+
D Cv[RN][RM] = {};
548+
for (; chunk + KN <= k; chunk += KN)
549+
gemm_bloc<RM, RN>(ii, jj, chunk, Cv);
550+
551+
while (sp--)
552+
for (j = 0; j < RN; ++j)
553+
for (i = 0; i < RM; ++i)
554+
Cv[j][i] += stack[sp][j][i];
555+
556+
float Cf[RN][RM];
557+
for (j = 0; j < RN; ++j)
558+
for (i = 0; i < RM; ++i)
559+
Cf[j][i] = hsum(Cv[j][i]);
597560

561+
for (; chunk < k; ++chunk)
598562
for (j = 0; j < RN; ++j)
599563
for (i = 0; i < RM; ++i)
600-
store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
564+
Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
565+
load<float>(INDEX(B, ldb, jj + j, chunk)), //
566+
Cf[j][i]);
567+
568+
for (j = 0; j < RN; ++j)
569+
for (i = 0; i < RM; ++i)
570+
store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
571+
}
572+
573+
template <int RM, int RN, int BM>
574+
NOINLINE void gemm(long m, long n) {
575+
// GGML_ASSERT(m % (RM * BM) == 0);
576+
const long ytiles = m / (RM * BM);
577+
const long xtiles = (n + RN -1) / RN;
578+
const long jj_RN = (xtiles - (xtiles * RN - n));
579+
580+
long tiles = xtiles * ytiles;
581+
long duty = (tiles + nth - 1) / nth;
582+
long start = duty * ith;
583+
long end = start + duty;
584+
if (end > tiles)
585+
end = tiles;
586+
for (int64_t job = start; job < end; ++job) {
587+
const int64_t ii = job / xtiles;
588+
const int64_t jj = job % xtiles;
589+
for (int64_t bi = 0; bi < BM; ++bi) {
590+
if (jj < jj_RN) {
591+
gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
592+
} else if constexpr (RN > 1) {
593+
gemm_bloc<RM, RN - 1>((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1));
594+
}
595+
}
601596
}
602597
}
603598

0 commit comments

Comments
 (0)