Skip to content

Commit 79f26e9

Browse files
authored
vulkan: Add bfloat16 support (ggml-org#12554)
* vulkan: Add bfloat16 support

  This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16. The extension is required for coopmat multiply support, but matrix-vector multiply trivially promotes bf16 to fp32 and doesn't require the extension. The copy/get_rows shaders also don't require the extension.

  It's probably possible to fall back to non-coopmat and promote to fp32 when the extension isn't supported, but this change doesn't do that.

  The coopmat support also requires a glslc that supports the extension, which currently requires a custom build.

* vulkan: Support bf16 tensors without the bf16 extension or coopmat support

  Compile a variant of the scalar mul_mm shader that will promote the bf16 values to float, and use that when either the bf16 extension or the coopmat extensions aren't available.

* vulkan: bfloat16 fixes (really works without bfloat16 support now)

* vulkan: fix spirv-val failure and re-enable -O
1 parent fc727bc commit 79f26e9

13 files changed

+366
-65
lines changed

ggml/src/ggml-vulkan/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,22 @@ if (Vulkan_FOUND)
7171
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
7272
endif()
7373

74+
# Probe glslc for GL_EXT_bfloat16 by compiling a small test shader at configure time.
# stdout and stderr of the compiler run are captured; an unsupported extension is
# detected by searching the captured stderr for glslc's "extension not supported"
# diagnostic. Any other outcome is treated as "supported".
# The result is stored in GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT (ON/OFF) and, when ON,
# is also exposed to the C++ sources as a compile definition of the same name.
execute_process(
    COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3
            "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
    OUTPUT_VARIABLE glslc_output
    ERROR_VARIABLE glslc_error)

# Quote the expansion so the captured stderr always reaches if() as exactly one
# argument, even when it is empty or contains semicolons (an unquoted ${...}
# vanishes when empty and is split into several arguments on ';').
if ("${glslc_error}" MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
    message(STATUS "GL_EXT_bfloat16 not supported by glslc")
    set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
else()
    message(STATUS "GL_EXT_bfloat16 supported by glslc")
    set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
    add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
endif()
89+
7490
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
7591
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
7692

@@ -142,6 +158,7 @@ if (Vulkan_FOUND)
142158
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
143159
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
144160
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
161+
-DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
145162
BUILD_COMMAND ${CMAKE_COMMAND} --build .
146163
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
147164
INSTALL_DIR ${CMAKE_BINARY_DIR}

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 175 additions & 21 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ endif()
1212
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
1313
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
1414
endif()
15+
# When GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT is set to a true value, define the
# preprocessor symbol of the same name for sources compiled in this directory.
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
16+
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
17+
endif()
1518
set(TARGET vulkan-shaders-gen)
1619
add_executable(${TARGET} vulkan-shaders-gen.cpp)
1720
install(TARGETS ${TARGET} RUNTIME)

ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ void main() {
1818
// fast path for when all four iterations are in-bounds
1919
if (idx + (num_iter-1)*num_threads < p.ne) {
2020
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
21-
#ifndef OPTIMIZATION_ERROR_WORKAROUND
21+
22+
#if defined(DATA_D_BF16)
23+
float f = float(data_a[get_aoffset() + idx]);
24+
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
25+
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
2226
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
2327
#else
2428
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
@@ -31,7 +35,10 @@ void main() {
3135
continue;
3236
}
3337

34-
#ifndef OPTIMIZATION_ERROR_WORKAROUND
38+
#if defined(DATA_D_BF16)
39+
float f = float(data_a[get_aoffset() + idx]);
40+
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
41+
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
3542
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
3643
#else
3744
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];

ggml/src/ggml-vulkan/vulkan-shaders/copy.comp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ void main() {
1212
return;
1313
}
1414

15-
#ifndef OPTIMIZATION_ERROR_WORKAROUND
15+
#if defined(DATA_D_BF16)
16+
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
17+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
18+
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
1619
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
1720
#else
1821
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2323
}
2424
#endif
2525

26+
#if defined(DATA_A_BF16)
27+
// dequantize() variant compiled when DATA_A_BF16 is defined.
// Reads the two consecutive entries data_a[a_offset + ib] and
// data_a[a_offset + ib + 1], passes each through bf16_to_fp32(), and returns
// the pair as a vec2. The iqs parameter is not used by this variant.
// NOTE(review): data_a's element type and bf16_to_fp32() are declared outside
// this hunk — presumably raw 16-bit bf16 storage widened to float; confirm.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
28+
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
29+
}
30+
#endif
31+
2632
#if defined(DATA_A_Q4_0)
2733
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2834
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
428434
}
429435
#endif
430436

431-
#if defined(DATA_A_F32) || defined(DATA_A_F16)
437+
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
432438
vec2 get_dm(uint ib, uint a_offset) {
433439
return vec2(0, 0);
434440
}

ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,14 @@ void main() {
2020
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
2121
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
2222

23+
#if defined(DATA_A_BF16)
24+
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
25+
#else
26+
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
27+
#endif
2328
#ifndef OPTIMIZATION_ERROR_WORKAROUND
24-
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
29+
data_d[d_offset + i00] = D_TYPE(v);
2530
#else
26-
data_d[d_offset + i00] = data_a[a_offset + i00];
31+
data_d[d_offset + i00] = D_TYPE(v);
2732
#endif
2833
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88

9-
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
9+
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
1010
#define K_PER_ITER 8
1111
#else
1212
#define K_PER_ITER 2

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
1111
#endif
1212

13+
#if defined(DATA_A_BF16) && defined(COOPMAT)
14+
#extension GL_EXT_bfloat16 : enable
15+
#endif
16+
1317
#ifdef COOPMAT
1418
#extension GL_KHR_cooperative_matrix : enable
1519
#extension GL_KHR_memory_scope_semantics : enable
@@ -29,6 +33,10 @@
2933
#define LOAD_VEC_B 1
3034
#endif
3135

36+
#if !defined(TO_FLOAT_TYPE)
37+
#define TO_FLOAT_TYPE FLOAT_TYPE
38+
#endif
39+
3240
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3341

3442
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -202,8 +210,8 @@ void main() {
202210
#endif
203211

204212
#ifdef COOPMAT
205-
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
206-
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
213+
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
214+
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
207215
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
208216

209217
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
@@ -248,6 +256,21 @@ void main() {
248256
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
249257
}
250258
#endif
259+
#elif defined(DATA_A_BF16)
260+
#if LOAD_VEC_A == 4
261+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
262+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
263+
buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
264+
buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
265+
buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
266+
buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
267+
#else
268+
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
269+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
270+
} else {
271+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
272+
}
273+
#endif
251274
#elif defined(DATA_A_Q4_0)
252275
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
253276
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
@@ -695,21 +718,21 @@ void main() {
695718
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
696719
#endif
697720
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
698-
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
699-
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
700-
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
701-
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
721+
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
722+
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
723+
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
724+
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
702725
#elif !MUL_MAT_ID
703726
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
704-
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
727+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
705728
} else {
706729
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
707730
}
708731
#else
709732
const uint row_i = ic * BN + loadc_b + l;
710733
if (row_i < _ne1) {
711734
const u16vec2 row_idx = row_ids[row_i];
712-
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
735+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
713736
} else {
714737
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
715738
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#extension GL_EXT_buffer_reference : enable
1515
#extension GL_KHR_shader_subgroup_ballot : enable
1616
#extension GL_KHR_shader_subgroup_vote : enable
17+
#ifdef DATA_A_BF16
18+
#extension GL_EXT_bfloat16 : enable
19+
#endif
1720

1821
#include "types.comp"
1922

@@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
8083
#define store_scales(a)
8184
#endif
8285

86+
#if defined(DATA_A_BF16)
87+
#define MAT_TYPE bfloat16_t
88+
#else
89+
#define MAT_TYPE FLOAT_TYPE
90+
#endif
91+
8392
#ifdef MUL_MAT_ID
8493
layout (binding = 3) readonly buffer IDS {int data_ids[];};
8594

@@ -271,8 +280,8 @@ void main() {
271280

272281
// Manually partial unroll
273282
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
274-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
283+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
284+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
276285

277286
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278287
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -286,8 +295,8 @@ void main() {
286295
store_scales(tid);
287296
}
288297
while (block_k < end_k) {
289-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
290-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
298+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
299+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
291300

292301
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
293302
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -310,8 +319,8 @@ void main() {
310319

311320
// Manually partial unroll
312321
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
313-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
314-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
322+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
323+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
315324

316325
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
317326
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -325,8 +334,8 @@ void main() {
325334
store_scales(tid);
326335
}
327336
while (block_k < end_k) {
328-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
329-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
337+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
338+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
330339

331340
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
332341
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -350,8 +359,8 @@ void main() {
350359

351360
// Manually partial unroll
352361
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
353-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
354-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
362+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
363+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
355364

356365
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
357366
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -365,8 +374,8 @@ void main() {
365374
store_scales(tid);
366375
}
367376
while (block_k < end_k) {
368-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
369-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
377+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
378+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
370379

371380
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
372381
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -405,8 +414,8 @@ void main() {
405414
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
406415
}
407416

408-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
409-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
417+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
410419

411420
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
412421
#ifdef MUL_MAT_ID

0 commit comments

Comments
 (0)