4949#pragma GCC diagnostic ignored "-Wpedantic"
5050#pragma GCC diagnostic ignored "-Wignored-attributes"
5151
52- #define CHUNK 8
52+ #define CHUNK 16
5353#define ROW_ALIGN 64
5454#define MATRIX_ALIGN 4096
5555#define MAX_ALIGN 4096
@@ -416,6 +416,13 @@ inline void store(ggml_bf16_t *p, float f) {
416416// //////////////////////////////////////////////////////////////////////////////////////////////////
417417// FLOATING POINT MATRIX MULTIPLICATION
418418
419+ template <int M>
420+ static long BLOCK_SIZE (long m) {
421+ if (m % M == 0 ) return M;
422+ const long NB_BLOC_M = (m + M - 1 ) / M;
423+ return (m / NB_BLOC_M) + 1 ;
424+ }
425+
419426template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
420427class tinyBLAS {
421428 public:
@@ -424,180 +431,169 @@ class tinyBLAS {
424431 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
425432 }
426433
427- void matmul (long m, long n) {
428- mnpack (0 , m, 0 , n);
429- }
430-
431- private:
432- NOINLINE void mnpack (long m0, long m, long n0, long n) {
433- long mc, nc, mp, np;
434-
434+ bool matmul (long m, long n) {
435435#if VECTOR_REGISTERS == 32
436- switch ((MIN (m - m0, 5 ) << 4 ) | MIN (n - n0, 5 )) {
437- case 0x55 :
438- mc = 5 ;
439- nc = 5 ;
440- gemm<5 , 5 >(m0, m, n0, n);
441- break ;
442- case 0x54 :
443- case 0x53 :
444- case 0x52 :
445- case 0x45 :
446- case 0x44 :
447- case 0x43 :
448- case 0x42 :
449- case 0x35 :
450- case 0x34 :
451- case 0x33 :
452- case 0x32 :
453- case 0x25 :
454- case 0x24 :
455- case 0x23 :
456- case 0x22 :
457- mc = 2 ;
458- nc = 2 ;
459- gemm<2 , 2 >(m0, m, n0, n);
460- break ;
461- case 0x51 :
462- case 0x41 :
463- case 0x31 :
464- case 0x21 :
465- mc = 2 ;
466- nc = 1 ;
467- gemm<2 , 1 >(m0, m, n0, n);
468- break ;
469- case 0x15 :
470- case 0x14 :
471- case 0x13 :
472- case 0x12 :
473- mc = 1 ;
474- nc = 2 ;
475- gemm<1 , 2 >(m0, m, n0, n);
476- break ;
477- case 0x11 :
478- mc = 1 ;
479- nc = 1 ;
480- gemm<1 , 1 >(m0, m, n0, n);
481- break ;
482- default :
483- return ;
436+ if (m % 8 == 0 && n < 4 ) {
437+ mnpack<8 , 3 , 1 >(m, n, n);
438+ return true ;
484439 }
485- #endif
486-
487- #if VECTOR_REGISTERS == 16
488- switch ((MIN (m - m0, 4 ) << 4 ) | MIN (n - n0, 3 )) {
489- case 0x43 :
490- mc = 4 ;
491- nc = 3 ;
492- gemm<4 , 3 >(m0, m, n0, n);
493- break ;
494- case 0x42 :
495- case 0x33 :
496- case 0x32 :
497- case 0x23 :
498- case 0x22 :
499- mc = 2 ;
500- nc = 2 ;
501- gemm<2 , 2 >(m0, m, n0, n);
502- break ;
503- case 0x41 :
504- case 0x31 :
505- case 0x21 :
506- mc = 2 ;
507- nc = 1 ;
508- gemm<2 , 1 >(m0, m, n0, n);
509- break ;
510- case 0x13 :
511- case 0x12 :
512- mc = 1 ;
513- nc = 2 ;
514- gemm<1 , 2 >(m0, m, n0, n);
515- break ;
516- case 0x11 :
517- mc = 1 ;
518- nc = 1 ;
519- gemm<1 , 1 >(m0, m, n0, n);
520- break ;
521- default :
522- return ;
440+ if (m % 16 == 0 ) {
441+ const long SIZE_N = BLOCK_SIZE<6 >(n);
442+ mnpack<4 , 6 , 4 >(m, n, SIZE_N);
443+ return true ;
444+ }
445+ if (m % 8 == 0 ) {
446+ const long SIZE_N = BLOCK_SIZE<6 >(n);
447+ mnpack<4 , 6 , 2 >(m, n, SIZE_N);
448+ return true ;
449+ }
450+ if (m % 4 == 0 ) {
451+ const long SIZE_N = BLOCK_SIZE<6 >(n);
452+ mnpack<4 , 6 , 1 >(m, n, SIZE_N);
453+ return true ;
454+ }
455+ #else // VECTOR_REGISTERS == 16
456+ if (m % 4 == 0 && n < 3 ) {
457+ mnpack<4 , 2 , 1 >(m, n, n);
458+ return true ;
459+ }
460+ if (m % 16 == 0 ) {
461+ const long SIZE_N = BLOCK_SIZE<3 >(n);
462+ mnpack<4 , 3 , 4 >(m, n, SIZE_N);
463+ return true ;
464+ }
465+ if (m % 8 == 0 ) {
466+ const long SIZE_N = BLOCK_SIZE<3 >(n);
467+ mnpack<4 , 3 , 2 >(m, n, SIZE_N);
468+ return true ;
469+ }
470+ if (m % 4 == 0 ) {
471+ const long SIZE_N = BLOCK_SIZE<3 >(n);
472+ mnpack<4 , 3 , 1 >(m, n, SIZE_N);
473+ return true ;
523474 }
524475#endif
476+ return false ;
477+ }
525478
526- mp = m0 + (m - m0) / mc * mc;
527- np = n0 + (n - n0) / nc * nc;
528- mnpack (mp, m, n0, np);
529- mnpack (m0, m, np, n);
479+ private:
480+ template <int RM, int RN, int BM>
481+ inline void mnpack (long m, long n, long SIZE_N) {
482+ if (SIZE_N == RN) {
483+ return gemm<RM, RN, BM>(m, n);
484+ }
485+ if constexpr (RN > 1 ) {
486+ return mnpack<RM, RN-1 , BM>(m, n, SIZE_N);
487+ // } else {
488+ // GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
489+ // GGML_ASSERT(false); // we have miss something.
490+ }
530491 }
531492
532493 template <int RM, int RN>
533- NOINLINE void gemm (long m0, long m, long n0, long n) {
534- D stack[bsr (k / CHUNK + 1 ) + 1 ][RN][RM];
535- long ytiles = RM > 1 ? (m - m0) / RM : 1 ;
536- long xtiles = RN > 1 ? (n - n0) / RN : 1 ;
537- long tiles = xtiles * ytiles;
538- long duty = (tiles + nth - 1 ) / nth;
539- long start = duty * ith;
540- long end = start + duty;
541- if (end > tiles)
542- end = tiles;
543- for (long job = start; job < end; ++job) {
544- long ii = m0 + job / xtiles * RM;
545- long jj = n0 + job % xtiles * RN;
546-
547- size_t chunk, sp = 0 ;
548- int i, j, rule, step = 2 ;
549- for (chunk = 0 ; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4 , step += 2 , ++sp) {
550-
551- D Cv[RN][RM] = {};
552- for (long l = 0 ; l < KN * CHUNK * 4 ; l += KN)
494+ inline void gemm_bloc (long ii, long jj, long l, D Cv[RN][RM]) {
495+ // help compiler for op order.
496+ if constexpr (RM <= RN) {
497+ V Av[RM];
553498#pragma GCC unroll 100
554- for (j = 0 ; j < RN; ++j)
499+ for (int64_t i = 0 ; i < RM; ++i) {
500+ Av[i] = load<V>(A + lda * (ii + i) + l);
501+ }
555502#pragma GCC unroll 100
556- for (i = 0 ; i < RM; ++i)
557- Cv[j][i] = madd (load<V>(INDEX (A, lda, ii + i, chunk + l)), //
558- load<V>(INDEX (B, ldb, jj + j, chunk + l)), //
559- Cv[j][i]);
560-
561- for (rule = bsr (step & -step); --rule;)
562- for (--sp, j = 0 ; j < RN; ++j)
563- for (i = 0 ; i < RM; ++i)
564- Cv[j][i] += stack[sp][j][i];
565-
566- for (j = 0 ; j < RN; ++j)
567- for (i = 0 ; i < RM; ++i)
568- stack[sp][j][i] = Cv[j][i];
503+ for (int64_t j = 0 ; j < RN; ++j) {
504+ V Bv = load<V>(B + ldb * (jj + j) + l);
505+ #pragma GCC unroll 100
506+ for (int64_t i = 0 ; i < RM; ++i) {
507+ Cv[j][i] = madd (Av[i], Bv, Cv[j][i]);
508+ }
569509 }
570-
571- D Cv[RN][RM] = {};
572- for (; chunk + KN <= k; chunk += KN)
510+ } else {
511+ V Bv[RN];
573512#pragma GCC unroll 100
574- for (j = 0 ; j < RN; ++j)
513+ for (int64_t j = 0 ; j < RN; ++j) {
514+ Bv[j] = load<V>(B + ldb * (jj + j) + l);
515+ }
575516#pragma GCC unroll 100
576- for (i = 0 ; i < RM; ++i)
577- Cv[j][i] = madd (load<V>(INDEX (A, lda, ii + i, chunk)), //
578- load<V>(INDEX (B, ldb, jj + j, chunk)), //
579- Cv[j][i]);
517+ for (int64_t i = 0 ; i < RM; ++i) {
518+ V Av = load<V>(A + lda * (ii + i) + l);
519+ #pragma GCC unroll 100
520+ for (int64_t j = 0 ; j < RN; ++j) {
521+ Cv[j][i] = madd (Av, Bv[j], Cv[j][i]);
522+ }
523+ }
524+ }
525+ }
526+
527+ template <int RM, int RN>
528+ inline void gemm_bloc (long ii, long jj) {
529+ D stack[bsr (k / CHUNK + 1 ) + 1 ][RN][RM];
530+ long chunk, sp = 0 ;
531+ int i, j, rule, step = 2 ;
532+ for (chunk = 0 ; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4 , step += 2 , ++sp) {
580533
581- while (sp--)
582- for (j = 0 ; j < RN; ++j)
534+ D Cv[RN][RM] = {};
535+ for (long l = 0 ; l < KN * CHUNK * 4 ; l += KN)
536+ gemm_bloc<RM, RN>(ii, jj, chunk + l, Cv);
537+
538+ for (rule = bsr (step & -step); --rule;)
539+ for (--sp, j = 0 ; j < RN; ++j)
583540 for (i = 0 ; i < RM; ++i)
584541 Cv[j][i] += stack[sp][j][i];
585542
586- float Cf[RN][RM];
587543 for (j = 0 ; j < RN; ++j)
588544 for (i = 0 ; i < RM; ++i)
589- Cf[j][i] = hsum (Cv[j][i]);
545+ stack[sp][j][i] = Cv[j][i];
546+ }
590547
591- for (; chunk < k; ++chunk)
592- for (j = 0 ; j < RN; ++j)
593- for (i = 0 ; i < RM; ++i)
594- Cf[j][i] = fmaf (load<float >(INDEX (A, lda, ii + i, chunk)), //
595- load<float >(INDEX (B, ldb, jj + j, chunk)), //
596- Cf[j][i]);
548+ D Cv[RN][RM] = {};
549+ for (; chunk + KN <= k; chunk += KN)
550+ gemm_bloc<RM, RN>(ii, jj, chunk, Cv);
551+
552+ while (sp--)
553+ for (j = 0 ; j < RN; ++j)
554+ for (i = 0 ; i < RM; ++i)
555+ Cv[j][i] += stack[sp][j][i];
556+
557+ float Cf[RN][RM];
558+ for (j = 0 ; j < RN; ++j)
559+ for (i = 0 ; i < RM; ++i)
560+ Cf[j][i] = hsum (Cv[j][i]);
597561
562+ for (; chunk < k; ++chunk)
598563 for (j = 0 ; j < RN; ++j)
599564 for (i = 0 ; i < RM; ++i)
600- store (INDEX (C, ldc, jj + j, ii + i), Cf[j][i]);
565+ Cf[j][i] = fmaf (load<float >(INDEX (A, lda, ii + i, chunk)), //
566+ load<float >(INDEX (B, ldb, jj + j, chunk)), //
567+ Cf[j][i]);
568+
569+ for (j = 0 ; j < RN; ++j)
570+ for (i = 0 ; i < RM; ++i)
571+ store (INDEX (C, ldc, jj + j, ii + i), Cf[j][i]);
572+ }
573+
574+ template <int RM, int RN, int BM>
575+ NOINLINE void gemm (long m, long n) {
576+ // GGML_ASSERT(m % (RM * BM) == 0);
577+ const long ytiles = m / (RM * BM);
578+ const long xtiles = (n + RN -1 ) / RN;
579+ const long jj_RN = (xtiles - (xtiles * RN - n));
580+
581+ long tiles = xtiles * ytiles;
582+ long duty = (tiles + nth - 1 ) / nth;
583+ long start = duty * ith;
584+ long end = start + duty;
585+ if (end > tiles)
586+ end = tiles;
587+ for (int64_t job = start; job < end; ++job) {
588+ const int64_t ii = job / xtiles;
589+ const int64_t jj = job % xtiles;
590+ for (int64_t bi = 0 ; bi < BM; ++bi) {
591+ if (jj < jj_RN) {
592+ gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
593+ } else if constexpr (RN > 1 ) {
594+ gemm_bloc<RM, RN - 1 >((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1 ));
595+ }
596+ }
601597 }
602598 }
603599
0 commit comments