14
14
#extension GL_EXT_buffer_reference : enable
15
15
#extension GL_KHR_shader_subgroup_ballot : enable
16
16
#extension GL_KHR_shader_subgroup_vote : enable
17
+ #ifdef DATA_A_BF16
18
+ #extension GL_EXT_bfloat16 : enable
19
+ #endif
17
20
18
21
#include "types.comp"
19
22
@@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
80
83
#define store_scales(a)
81
84
#endif
82
85
86
+ #if defined(DATA_A_BF16)
87
+ #define MAT_TYPE bfloat16_t
88
+ #else
89
+ #define MAT_TYPE FLOAT_TYPE
90
+ #endif
91
+
83
92
#ifdef MUL_MAT_ID
84
93
layout (binding = 3) readonly buffer IDS {int data_ids[];};
85
94
@@ -271,8 +280,8 @@ void main() {
271
280
272
281
// Manually partial unroll
273
282
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
274
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
283
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
284
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
276
285
277
286
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278
287
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -286,8 +295,8 @@ void main() {
286
295
store_scales(tid);
287
296
}
288
297
while (block_k < end_k) {
289
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
290
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
298
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
299
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
291
300
292
301
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
293
302
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -310,8 +319,8 @@ void main() {
310
319
311
320
// Manually partial unroll
312
321
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
313
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
314
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
322
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
323
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
315
324
316
325
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
317
326
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -325,8 +334,8 @@ void main() {
325
334
store_scales(tid);
326
335
}
327
336
while (block_k < end_k) {
328
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
329
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
337
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
338
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
330
339
331
340
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
332
341
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -350,8 +359,8 @@ void main() {
350
359
351
360
// Manually partial unroll
352
361
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
353
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
354
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
362
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
363
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
355
364
356
365
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
357
366
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -365,8 +374,8 @@ void main() {
365
374
store_scales(tid);
366
375
}
367
376
while (block_k < end_k) {
368
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
369
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
377
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
378
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
370
379
371
380
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
372
381
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -405,8 +414,8 @@ void main() {
405
414
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
406
415
}
407
416
408
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
409
- coopmat<FLOAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
417
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418
+ coopmat<MAT_TYPE , gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
410
419
411
420
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
412
421
#ifdef MUL_MAT_ID
0 commit comments