1414#extension GL_EXT_buffer_reference : enable
1515#extension GL_KHR_shader_subgroup_ballot : enable
1616#extension GL_KHR_shader_subgroup_vote : enable
17+ #ifdef DATA_A_BF16
18+ #extension GL_EXT_bfloat16 : enable
19+ #endif
1720
1821#include "types.comp"
1922
@@ -70,6 +73,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
7073#define DECODEFUNCA
7174#endif
7275
76+ #if defined(DATA_A_BF16)
77+ #define MAT_TYPE bfloat16_t
78+ #else
79+ #define MAT_TYPE FLOAT_TYPE
80+ #endif
81+
7382#ifdef MUL_MAT_ID
7483layout (binding = 3) readonly buffer IDS {int data_ids[];};
7584
@@ -239,8 +248,8 @@ void main() {
239248 coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
240249 for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
241250
242- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
243- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
251+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
252+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
244253
245254 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
246255 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -255,8 +264,8 @@ void main() {
255264 coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
256265 for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
257266
258- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
259- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
267+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
268+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
260269
261270 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
262271 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -271,8 +280,8 @@ void main() {
271280 coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
272281 for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
273282
274- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
283+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
284+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
276285
277286 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278287 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -301,8 +310,8 @@ void main() {
301310 [[dont_unroll]]
302311 for (uint block_k = start_k; block_k < end_k; block_k += BK) {
303312
304- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
305- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
313+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
314+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
306315
307316 // Clamping is expensive, so detect different code paths for each combination
308317 // of A and B needing clamping.
0 commit comments