@@ -456,18 +456,105 @@ void main() {
456456
457457 tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
458458
459- coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
460- sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
461-
462459 uint k_iters = (end_k - start_k + BK - 1) / BK;
463460
464461 fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
462+ store_scales(tid);
463+
464+ #ifdef MUL_MAT_ID
465+ if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
466+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
467+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
468+
469+ [[dont_unroll]]
470+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
471+
472+ if ((block_k % QUANT_K) == 0) {
473+ store_scales(tid);
474+ }
475+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
476+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
477+ }
478+
479+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
480+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
481+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
482+
483+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
484+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
485+
486+ sum = coopMatMulAdd(mat_a, mat_b, sum);
487+ } else {
488+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
489+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
490+
491+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
492+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
493+
494+ sum = coopMatMulAdd(mat_a, mat_b, sum);
495+ }
496+ }
497+
498+ // Convert from ACC_TYPE to D_TYPE
499+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
500+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
501+
502+ // Call callback to store each element, remapping row through shared memory
503+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
504+ return;
505+ }
506+ if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
507+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
508+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
509+
510+ [[dont_unroll]]
511+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
512+
513+ if ((block_k % QUANT_K) == 0) {
514+ store_scales(tid);
515+ }
516+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
517+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
518+ }
519+
520+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
521+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
522+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
523+
524+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
525+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
526+
527+ sum = coopMatMulAdd(mat_a, mat_b, sum);
528+ } else {
529+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
530+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
531+
532+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
533+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
534+
535+ sum = coopMatMulAdd(mat_a, mat_b, sum);
536+ }
537+ }
538+
539+ // Convert from ACC_TYPE to D_TYPE
540+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
541+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
542+
543+ // Call callback to store each element, remapping row through shared memory
544+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
545+ return;
546+ }
547+ #endif
548+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
549+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
465550
466551 [[dont_unroll]]
467552 for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
468553
469- store_scales(tid);
470- if (block_k + BK < end_k) {
554+ if ((block_k % QUANT_K) == 0) {
555+ store_scales(tid);
556+ }
557+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
471558 fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
472559 }
473560
0 commit comments