1414#extension GL_EXT_buffer_reference : enable
1515#extension GL_KHR_shader_subgroup_ballot : enable
1616#extension GL_KHR_shader_subgroup_vote : enable
17+ #ifdef DATA_A_BF16
18+ #extension GL_EXT_bfloat16 : enable
19+ #endif
1720
1821#include "types.comp"
1922
@@ -70,6 +73,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
7073#define DECODEFUNCA
7174#endif
7275
76+ #if defined(DATA_A_BF16)
77+ #define MAT_TYPE bfloat16_t
78+ #else
79+ #define MAT_TYPE FLOAT_TYPE
80+ #endif
81+ 
7382#ifdef MUL_MAT_ID
7483layout (binding = 3) readonly buffer IDS {int data_ids[];};
7584
@@ -239,8 +248,8 @@ void main() {
239248            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
240249            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
241250
242-                 coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
243-                 coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
251+                 coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
252+                 coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
244253
245254                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
246255                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -255,8 +264,8 @@ void main() {
255264            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
256265            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
257266
258-                 coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
259-                 coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
267+                 coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
268+                 coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
260269
261270                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
262271                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -271,8 +280,8 @@ void main() {
271280            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
272281            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
273282
274-                 coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275-                 coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
283+                 coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
284+                 coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
276285
277286                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278287                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -301,8 +310,8 @@ void main() {
301310        [[dont_unroll]]
302311        for (uint block_k = start_k; block_k < end_k; block_k += BK) {
303312
304-             coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
305-             coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
313+             coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
314+             coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
306315
307316            // Clamping is expensive, so detect different code paths for each combination
308317            // of A and B needing clamping.
0 commit comments