@@ -162,17 +162,32 @@ void main() {
162162 _ne1 = 0;
163163 uint num_elements = p.nei1 * p.nei0;
164164
165- for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
165+ uint ids[16];
166+ uint iter = 0;
167+
168+ for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
169+ // prefetch the ids for this and the next 15 loop iterations (16 per invocation); out-of-range slots are set to 0
170+ if (iter == 0) {
171+ [[unroll]] for (uint k = 0; k < 16; ++k) {
172+ uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
173+ bool in_range = i < num_elements;
174+ uint ii1 = i / p.nei0;
175+ uint ii0 = i % p.nei0;
176+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
177+ }
178+ }
179+ uint i = j + gl_SubgroupInvocationID;
166180 bool in_range = i < num_elements;
167- uint ii0 = i % p.nei0;
168181 uint ii1 = i / p.nei0;
169- uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
182+ uint ii0 = i % p.nei0;
183+ uint id = ids[iter++];
170184 uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
171185 uint idx = subgroupBallotExclusiveBitCount(ballot);
172186 if (in_range && id == expert_idx) {
173187 row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
174188 }
175189 _ne1 += subgroupBallotBitCount(ballot);
190+ iter &= 15;
176191 }
177192 _ne1_sh = _ne1;
178193 }
@@ -414,17 +429,31 @@ void main() {
414429 fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
415430 }
416431
417- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
432+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
433+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
434+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
419435
420- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp , ir * BM, BM, block_k, BK) DECODEFUNCA);
436+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA , ir * BM, BM, block_k, BK) DECODEFUNCA);
421437#ifdef MUL_MAT_ID
422- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
438+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
423439#else
424- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
440+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
425441#endif
426442
427- sum = coopMatMulAdd(mat_a, mat_b, sum);
443+ sum = coopMatMulAdd(mat_a, mat_b, sum);
444+ } else {
445+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
446+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
447+
448+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
449+ #ifdef MUL_MAT_ID
450+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
451+ #else
452+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
453+ #endif
454+
455+ sum = coopMatMulAdd(mat_a, mat_b, sum);
456+ }
428457 }
429458
430459 // Convert from ACC_TYPE to D_TYPE
0 commit comments