4949#pragma GCC diagnostic ignored "-Wpedantic"
5050#pragma GCC diagnostic ignored "-Wignored-attributes"
5151
52- #define CHUNK 8
52+ #define CHUNK 16
5353#define ROW_ALIGN 64
5454#define MATRIX_ALIGN 4096
5555#define MAX_ALIGN 4096
@@ -416,6 +416,12 @@ inline void store(ggml_bf16_t *p, float f) {
416416// //////////////////////////////////////////////////////////////////////////////////////////////////
417417// FLOATING POINT MATRIX MULTIPLICATION
418418
/// Balanced block size for splitting `m` items into blocks of at most `M`.
///
/// The number of blocks is the smallest count that respects the cap,
/// ceil(m / M). The returned value is ceil(m / blocks), i.e. the size of the
/// largest block when `m` is spread as evenly as possible over that many
/// blocks.
///
/// @tparam M  maximum block size; must be positive.
/// @param  m  number of items to split.
/// @return the block size, in [1, M] for positive `m`; 0 when `m <= 0`.
template <int M>
static long BLOCK_SIZE(long m) {
    static_assert(M > 0, "BLOCK_SIZE: the maximum block size must be positive");
    // For m <= 0 the block count computed below would be zero, and taking a
    // remainder or quotient by zero is undefined behaviour. Report an empty
    // block size instead.
    if (m <= 0) {
        return 0;
    }
    const long NB_BLOC_M = (m + M - 1) / M;  // number of blocks, >= 1 here
    // ceil(m / NB_BLOC_M): same value as "m / NB + (m % NB != 0)".
    return (m + NB_BLOC_M - 1) / NB_BLOC_M;
}
424+
419425template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
420426class tinyBLAS {
421427 public:
@@ -424,180 +430,169 @@ class tinyBLAS {
424430 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
425431 }
426432
427- void matmul (long m, long n) {
428- mnpack (0 , m, 0 , n);
429- }
430-
431- private:
432- NOINLINE void mnpack (long m0, long m, long n0, long n) {
433- long mc, nc, mp, np;
434-
433+ bool matmul (long m, long n) {
435434#if VECTOR_REGISTERS == 32
436- switch ((MIN (m - m0, 5 ) << 4 ) | MIN (n - n0, 5 )) {
437- case 0x55 :
438- mc = 5 ;
439- nc = 5 ;
440- gemm<5 , 5 >(m0, m, n0, n);
441- break ;
442- case 0x54 :
443- case 0x53 :
444- case 0x52 :
445- case 0x45 :
446- case 0x44 :
447- case 0x43 :
448- case 0x42 :
449- case 0x35 :
450- case 0x34 :
451- case 0x33 :
452- case 0x32 :
453- case 0x25 :
454- case 0x24 :
455- case 0x23 :
456- case 0x22 :
457- mc = 2 ;
458- nc = 2 ;
459- gemm<2 , 2 >(m0, m, n0, n);
460- break ;
461- case 0x51 :
462- case 0x41 :
463- case 0x31 :
464- case 0x21 :
465- mc = 2 ;
466- nc = 1 ;
467- gemm<2 , 1 >(m0, m, n0, n);
468- break ;
469- case 0x15 :
470- case 0x14 :
471- case 0x13 :
472- case 0x12 :
473- mc = 1 ;
474- nc = 2 ;
475- gemm<1 , 2 >(m0, m, n0, n);
476- break ;
477- case 0x11 :
478- mc = 1 ;
479- nc = 1 ;
480- gemm<1 , 1 >(m0, m, n0, n);
481- break ;
482- default :
483- return ;
435+ if (m % 8 == 0 && n < 4 ) {
436+ mnpack<8 , 3 , 1 >(m, n, n);
437+ return true ;
484438 }
485- #endif
486-
487- #if VECTOR_REGISTERS == 16
488- switch ((MIN (m - m0, 4 ) << 4 ) | MIN (n - n0, 3 )) {
489- case 0x43 :
490- mc = 4 ;
491- nc = 3 ;
492- gemm<4 , 3 >(m0, m, n0, n);
493- break ;
494- case 0x42 :
495- case 0x33 :
496- case 0x32 :
497- case 0x23 :
498- case 0x22 :
499- mc = 2 ;
500- nc = 2 ;
501- gemm<2 , 2 >(m0, m, n0, n);
502- break ;
503- case 0x41 :
504- case 0x31 :
505- case 0x21 :
506- mc = 2 ;
507- nc = 1 ;
508- gemm<2 , 1 >(m0, m, n0, n);
509- break ;
510- case 0x13 :
511- case 0x12 :
512- mc = 1 ;
513- nc = 2 ;
514- gemm<1 , 2 >(m0, m, n0, n);
515- break ;
516- case 0x11 :
517- mc = 1 ;
518- nc = 1 ;
519- gemm<1 , 1 >(m0, m, n0, n);
520- break ;
521- default :
522- return ;
439+ if (m % 16 == 0 ) {
440+ const long SIZE_N = BLOCK_SIZE<6 >(n);
441+ mnpack<4 , 6 , 4 >(m, n, SIZE_N);
442+ return true ;
443+ }
444+ if (m % 8 == 0 ) {
445+ const long SIZE_N = BLOCK_SIZE<6 >(n);
446+ mnpack<4 , 6 , 2 >(m, n, SIZE_N);
447+ return true ;
448+ }
449+ if (m % 4 == 0 ) {
450+ const long SIZE_N = BLOCK_SIZE<6 >(n);
451+ mnpack<4 , 6 , 1 >(m, n, SIZE_N);
452+ return true ;
453+ }
454+ #else // VECTOR_REGISTERS == 16
455+ if (m % 4 == 0 && n < 3 ) {
456+ mnpack<4 , 2 , 1 >(m, n, n);
457+ return true ;
458+ }
459+ if (m % 16 == 0 ) {
460+ const long SIZE_N = BLOCK_SIZE<3 >(n);
461+ mnpack<4 , 3 , 4 >(m, n, SIZE_N);
462+ return true ;
463+ }
464+ if (m % 8 == 0 ) {
465+ const long SIZE_N = BLOCK_SIZE<3 >(n);
466+ mnpack<4 , 3 , 2 >(m, n, SIZE_N);
467+ return true ;
468+ }
469+ if (m % 4 == 0 ) {
470+ const long SIZE_N = BLOCK_SIZE<3 >(n);
471+ mnpack<4 , 3 , 1 >(m, n, SIZE_N);
472+ return true ;
523473 }
524474#endif
475+ return false ;
476+ }
525477
526- mp = m0 + (m - m0) / mc * mc;
527- np = n0 + (n - n0) / nc * nc;
528- mnpack (mp, m, n0, np);
529- mnpack (m0, m, np, n);
478+ private:
479+ template <int RM, int RN, int BM>
480+ inline void mnpack (long m, long n, long SIZE_N) {
481+ if (SIZE_N == RN) {
482+ return gemm<RM, RN, BM>(m, n);
483+ }
484+ if constexpr (RN > 1 ) {
485+ return mnpack<RM, RN-1 , BM>(m, n, SIZE_N);
486+ // } else {
487+ // GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
488+ // GGML_ASSERT(false); // we have miss something.
489+ }
530490 }
531491
532492 template <int RM, int RN>
533- NOINLINE void gemm (long m0, long m, long n0, long n) {
534- D stack[bsr (k / CHUNK + 1 ) + 1 ][RN][RM];
535- long ytiles = RM > 1 ? (m - m0) / RM : 1 ;
536- long xtiles = RN > 1 ? (n - n0) / RN : 1 ;
537- long tiles = xtiles * ytiles;
538- long duty = (tiles + nth - 1 ) / nth;
539- long start = duty * ith;
540- long end = start + duty;
541- if (end > tiles)
542- end = tiles;
543- for (long job = start; job < end; ++job) {
544- long ii = m0 + job / xtiles * RM;
545- long jj = n0 + job % xtiles * RN;
546-
547- size_t chunk, sp = 0 ;
548- int i, j, rule, step = 2 ;
549- for (chunk = 0 ; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4 , step += 2 , ++sp) {
550-
551- D Cv[RN][RM] = {};
552- for (long l = 0 ; l < KN * CHUNK * 4 ; l += KN)
493+ inline void gemm_bloc (long ii, long jj, long l, D Cv[RN][RM]) {
494+ // help compiler for op order.
495+ if constexpr (RM <= RN) {
496+ V Av[RM];
553497#pragma GCC unroll 100
554- for (j = 0 ; j < RN; ++j)
498+ for (int64_t i = 0 ; i < RM; ++i) {
499+ Av[i] = load<V>(A + lda * (ii + i) + l);
500+ }
555501#pragma GCC unroll 100
556- for (i = 0 ; i < RM; ++i)
557- Cv[j][i] = madd (load<V>(INDEX (A, lda, ii + i, chunk + l)), //
558- load<V>(INDEX (B, ldb, jj + j, chunk + l)), //
559- Cv[j][i]);
560-
561- for (rule = bsr (step & -step); --rule;)
562- for (--sp, j = 0 ; j < RN; ++j)
563- for (i = 0 ; i < RM; ++i)
564- Cv[j][i] += stack[sp][j][i];
565-
566- for (j = 0 ; j < RN; ++j)
567- for (i = 0 ; i < RM; ++i)
568- stack[sp][j][i] = Cv[j][i];
502+ for (int64_t j = 0 ; j < RN; ++j) {
503+ V Bv = load<V>(B + ldb * (jj + j) + l);
504+ #pragma GCC unroll 100
505+ for (int64_t i = 0 ; i < RM; ++i) {
506+ Cv[j][i] = madd (Av[i], Bv, Cv[j][i]);
507+ }
569508 }
570-
571- D Cv[RN][RM] = {};
572- for (; chunk + KN <= k; chunk += KN)
509+ } else {
510+ V Bv[RN];
573511#pragma GCC unroll 100
574- for (j = 0 ; j < RN; ++j)
512+ for (int64_t j = 0 ; j < RN; ++j) {
513+ Bv[j] = load<V>(B + ldb * (jj + j) + l);
514+ }
575515#pragma GCC unroll 100
576- for (i = 0 ; i < RM; ++i)
577- Cv[j][i] = madd (load<V>(INDEX (A, lda, ii + i, chunk)), //
578- load<V>(INDEX (B, ldb, jj + j, chunk)), //
579- Cv[j][i]);
516+ for (int64_t i = 0 ; i < RM; ++i) {
517+ V Av = load<V>(A + lda * (ii + i) + l);
518+ #pragma GCC unroll 100
519+ for (int64_t j = 0 ; j < RN; ++j) {
520+ Cv[j][i] = madd (Av, Bv[j], Cv[j][i]);
521+ }
522+ }
523+ }
524+ }
525+
526+ template <int RM, int RN>
527+ inline void gemm_bloc (long ii, long jj) {
528+ D stack[bsr (k / CHUNK + 1 ) + 1 ][RN][RM];
529+ long chunk, sp = 0 ;
530+ int i, j, rule, step = 2 ;
531+ for (chunk = 0 ; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4 , step += 2 , ++sp) {
580532
581- while (sp--)
582- for (j = 0 ; j < RN; ++j)
533+ D Cv[RN][RM] = {};
534+ for (long l = 0 ; l < KN * CHUNK * 4 ; l += KN)
535+ gemm_bloc<RM, RN>(ii, jj, chunk + l, Cv);
536+
537+ for (rule = bsr (step & -step); --rule;)
538+ for (--sp, j = 0 ; j < RN; ++j)
583539 for (i = 0 ; i < RM; ++i)
584540 Cv[j][i] += stack[sp][j][i];
585541
586- float Cf[RN][RM];
587542 for (j = 0 ; j < RN; ++j)
588543 for (i = 0 ; i < RM; ++i)
589- Cf[j][i] = hsum (Cv[j][i]);
544+ stack[sp][j][i] = Cv[j][i];
545+ }
590546
591- for (; chunk < k; ++chunk)
592- for (j = 0 ; j < RN; ++j)
593- for (i = 0 ; i < RM; ++i)
594- Cf[j][i] = fmaf (load<float >(INDEX (A, lda, ii + i, chunk)), //
595- load<float >(INDEX (B, ldb, jj + j, chunk)), //
596- Cf[j][i]);
547+ D Cv[RN][RM] = {};
548+ for (; chunk + KN <= k; chunk += KN)
549+ gemm_bloc<RM, RN>(ii, jj, chunk, Cv);
550+
551+ while (sp--)
552+ for (j = 0 ; j < RN; ++j)
553+ for (i = 0 ; i < RM; ++i)
554+ Cv[j][i] += stack[sp][j][i];
555+
556+ float Cf[RN][RM];
557+ for (j = 0 ; j < RN; ++j)
558+ for (i = 0 ; i < RM; ++i)
559+ Cf[j][i] = hsum (Cv[j][i]);
597560
561+ for (; chunk < k; ++chunk)
598562 for (j = 0 ; j < RN; ++j)
599563 for (i = 0 ; i < RM; ++i)
600- store (INDEX (C, ldc, jj + j, ii + i), Cf[j][i]);
564+ Cf[j][i] = fmaf (load<float >(INDEX (A, lda, ii + i, chunk)), //
565+ load<float >(INDEX (B, ldb, jj + j, chunk)), //
566+ Cf[j][i]);
567+
568+ for (j = 0 ; j < RN; ++j)
569+ for (i = 0 ; i < RM; ++i)
570+ store (INDEX (C, ldc, jj + j, ii + i), Cf[j][i]);
571+ }
572+
573+ template <int RM, int RN, int BM>
574+ NOINLINE void gemm (long m, long n) {
575+ // GGML_ASSERT(m % (RM * BM) == 0);
576+ const long ytiles = m / (RM * BM);
577+ const long xtiles = (n + RN -1 ) / RN;
578+ const long jj_RN = (xtiles - (xtiles * RN - n));
579+
580+ long tiles = xtiles * ytiles;
581+ long duty = (tiles + nth - 1 ) / nth;
582+ long start = duty * ith;
583+ long end = start + duty;
584+ if (end > tiles)
585+ end = tiles;
586+ for (int64_t job = start; job < end; ++job) {
587+ const int64_t ii = job / xtiles;
588+ const int64_t jj = job % xtiles;
589+ for (int64_t bi = 0 ; bi < BM; ++bi) {
590+ if (jj < jj_RN) {
591+ gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
592+ } else if constexpr (RN > 1 ) {
593+ gemm_bloc<RM, RN - 1 >((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1 ));
594+ }
595+ }
601596 }
602597 }
603598
0 commit comments