@@ -451,6 +451,51 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
451451 auto r3 = ne13 / ne03;
452452
453453 if (ne13 == 1 && Ny == 1 && r2 > 1) {
454+ if (Nx >= 256 && Nx%32 == 0) {
455+ int nx32 = Nx/32;
456+ int nchunk = nx32*ne02;
457+ if (r2 <= 8) {
458+ MulMat mm;
459+ if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
460+ int nx64 = Nx/64;
461+ int nchunk64 = nx64*ne02;
462+ for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
463+ int i02 = ichunk/nx64;
464+ int ix = 64*(ichunk - i02*nx64);
465+ DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
466+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
467+ }
468+ int ix0 = 64*nx64;
469+ if (ix0 < Nx) {
470+ nx32 -= 2*nx64;
471+ nchunk = nx32*ne02;
472+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
473+ int i02 = ichunk/nx32;
474+ int ix = ix0 + 32*(ichunk - i02*nx32);
475+ DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
476+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
477+ }
478+ }
479+ //for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
480+ // int i02 = ichunk/nx32;
481+ // int ix = 32*(ichunk - i02*nx32);
482+ // DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
483+ // mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
484+ //}
485+ return true;
486+ }
487+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
488+ int i02 = ichunk/nx32;
489+ int ix = ichunk - i02*nx32;
490+ if (!iqk_mul_mat(32, r2, ne00,
491+ typeA, (const char *)A + 32*ix*strideA + i02*nb02, strideA,
492+ typeB, (const char *)B + i02*r2*nb12, nb12,
493+ C + 32*ix + r2*i02*nb2, nb2, 0, 1)) return false;
494+
495+ }
496+ return true;
497+ }
498+ //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
454499 int gcd = simple_gcd(ne02, nth);
455500 int counter = 0;
456501 for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -17153,6 +17198,14 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
1715317198 FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
1715417199 fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
1715517200 }
17201+ else if (nq1 >= 4) {
17202+ FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
17203+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17204+ }
17205+ else if (nq1 >= 2) {
17206+ FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
17207+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17208+ }
1715617209 else {
1715717210 FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
1715817211 fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
0 commit comments