@@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
 layout (constant_id = 2) const uint BN = 64;
 layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
 
+layout (constant_id = 4) const bool enable_smaller_matrices = false;
+const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
+const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
+
 layout (push_constant) uniform parameter
 {
     uint M;
@@ -168,15 +172,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
-    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
-
 #ifdef MUL_MAT_ID
     uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
     uint pos_b = 0;
 #else
     uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
     uint pos_b = batch_idx * p.batch_stride_b;
+    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
     uint stride_a = p.stride_a / QUANT_K;
@@ -197,6 +199,7 @@ void main() {
     tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
 
 #if QUANT_K > 1
     tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@@ -232,16 +235,54 @@ void main() {
         tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
 
         uint k_iters = (end_k - start_k + BK - 1) / BK;
+        if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
 
-        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
 
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
+            return;
+        } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
+            return;
+        } else {
+            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
-            coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-            coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
 
-            sum = coopMatMulAdd(mat_a, mat_b, sum);
+            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+            return;
         }
     } else
 #endif // !defined(MUL_MAT_ID)
@@ -254,6 +295,9 @@ void main() {
 
         tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
 
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
+
         [[dont_unroll]]
         for (uint block_k = start_k; block_k < end_k; block_k += BK) {
 
@@ -296,19 +340,16 @@ void main() {
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
             }
         }
-    }
 
-    // Convert from ACC_TYPE to D_TYPE
-    coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
-    mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
+         // Convert from ACC_TYPE to D_TYPE
+         coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
+         mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
 
 #ifdef MUL_MAT_ID
-    // Call callback to store each element, remapping row through shared memory
-    coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+         // Call callback to store each element, remapping row through shared memory
+         coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
 #else
-    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
-
-    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
-    coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
+        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
 #endif
+    }
 }
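
The `enable_smaller_matrices` flag this diff introduces is a specialization constant (`constant_id = 4`), so the host decides at pipeline-creation time whether the narrower `BNover2`/`BNover4` paths are compiled in. The fragment below is only a minimal sketch of how a Vulkan host application could feed that constant to the shader; it is not the ggml-vulkan backend's actual plumbing, `create_mul_mm_pipeline` and its parameters are hypothetical names, and a real setup would also supply entries for the other specialization constants (IDs 0–3).

```c
// Minimal sketch (assumed host-side code, not taken from the repository):
// set "enable_smaller_matrices" (constant_id = 4) when building the compute pipeline.
#include <vulkan/vulkan.h>
#include <stddef.h>

VkPipeline create_mul_mm_pipeline(VkDevice device,
                                  VkShaderModule shader_module,
                                  VkPipelineLayout pipeline_layout,
                                  VkBool32 enable_smaller_matrices) {
    // Boolean specialization constants are passed as 32-bit values (VkBool32).
    const VkSpecializationMapEntry entries[] = {
        { .constantID = 4, .offset = 0, .size = sizeof(VkBool32) },
    };
    const VkSpecializationInfo spec_info = {
        .mapEntryCount = 1,
        .pMapEntries   = entries,
        .dataSize      = sizeof(enable_smaller_matrices),
        .pData         = &enable_smaller_matrices,
    };

    const VkComputePipelineCreateInfo info = {
        .sType  = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
        .stage  = {
            .sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
            .stage  = VK_SHADER_STAGE_COMPUTE_BIT,
            .module = shader_module,
            .pName  = "main",
            .pSpecializationInfo = &spec_info,
        },
        .layout = pipeline_layout,
    };

    VkPipeline pipeline = VK_NULL_HANDLE;
    vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &info, NULL, &pipeline);
    return pipeline;
}
```

Because the flag is a compile-time constant from the shader's point of view, the driver can eliminate the untaken branches entirely; a host that wants both behaviors simply builds two pipelines from the same shader module, one with the flag set and one without.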