@@ -456,18 +456,105 @@ void main() {
456
456
457
457
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
458
458
459
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
460
- sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
461
-
462
459
uint k_iters = (end_k - start_k + BK - 1) / BK;
463
460
464
461
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
462
+ store_scales(tid);
463
+
464
+ #ifdef MUL_MAT_ID
465
+ if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
466
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
467
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
468
+
469
+ [[dont_unroll]]
470
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
471
+
472
+ if ((block_k % QUANT_K) == 0) {
473
+ store_scales(tid);
474
+ }
475
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
476
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
477
+ }
478
+
479
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
480
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
481
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
482
+
483
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
484
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
485
+
486
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
487
+ } else {
488
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
489
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
490
+
491
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
492
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
493
+
494
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
495
+ }
496
+ }
497
+
498
+ // Convert from ACC_TYPE to D_TYPE
499
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
500
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
501
+
502
+ // Call callback to store each element, remapping row through shared memory
503
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
504
+ return;
505
+ }
506
+ if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
507
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
508
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
509
+
510
+ [[dont_unroll]]
511
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
512
+
513
+ if ((block_k % QUANT_K) == 0) {
514
+ store_scales(tid);
515
+ }
516
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
517
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
518
+ }
519
+
520
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
521
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
522
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
523
+
524
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
525
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
526
+
527
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
528
+ } else {
529
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
530
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
531
+
532
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
533
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
534
+
535
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
536
+ }
537
+ }
538
+
539
+ // Convert from ACC_TYPE to D_TYPE
540
+ coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
541
+ mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
542
+
543
+ // Call callback to store each element, remapping row through shared memory
544
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
545
+ return;
546
+ }
547
+ #endif
548
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
549
+ sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
465
550
466
551
[[dont_unroll]]
467
552
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
468
553
469
- store_scales(tid);
470
- if (block_k + BK < end_k) {
554
+ if ((block_k % QUANT_K) == 0) {
555
+ store_scales(tid);
556
+ }
557
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
471
558
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
472
559
}
473
560
0 commit comments