1414#extension GL_EXT_buffer_reference : enable
1515#extension GL_KHR_shader_subgroup_ballot : enable
1616#extension GL_KHR_shader_subgroup_vote : enable
17+ #ifdef DATA_A_BF16
18+ #extension GL_EXT_bfloat16 : enable
19+ #endif
1720
1821#include "types.comp"
1922
@@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
8083#define store_scales(a)
8184#endif
8285
86+ #if defined(DATA_A_BF16)
87+ #define MAT_TYPE bfloat16_t
88+ #else
89+ #define MAT_TYPE FLOAT_TYPE
90+ #endif
91+
8392#ifdef MUL_MAT_ID
8493layout (binding = 3) readonly buffer IDS {int data_ids[];};
8594
@@ -271,8 +280,8 @@ void main() {
271280
272281 // Manually partial unroll
273282 [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
274- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
283+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
284+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
276285
277286 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278287 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -286,8 +295,8 @@ void main() {
286295 store_scales(tid);
287296 }
288297 while (block_k < end_k) {
289- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
290- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
298+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
299+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
291300
292301 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
293302 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -310,8 +319,8 @@ void main() {
310319
311320 // Manually partial unroll
312321 [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
313- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
314- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
322+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
323+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
315324
316325 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
317326 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -325,8 +334,8 @@ void main() {
325334 store_scales(tid);
326335 }
327336 while (block_k < end_k) {
328- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
329- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
337+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
338+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
330339
331340 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
332341 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -350,8 +359,8 @@ void main() {
350359
351360 // Manually partial unroll
352361 [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
353- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
354- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
362+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
363+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
355364
356365 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
357366 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -365,8 +374,8 @@ void main() {
365374 store_scales(tid);
366375 }
367376 while (block_k < end_k) {
368- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
369- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
377+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
378+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
370379
371380 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
372381 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -405,8 +414,8 @@ void main() {
405414 fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
406415 }
407416
408- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
409- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
417+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
410419
411420 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
412421#ifdef MUL_MAT_ID
0 commit comments