Skip to content

Commit 9cd10f8

Browse files
committed
vulkan: Add bfloat16 support
This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16. The extension is required for coopmat multiply support, but matrix-vector multiply trivially promotes bf16 to fp32 and doesn't require the extension. The copy/get_rows shaders also don't require the extension. It's probably possible to fall back to non-coopmat and promote to fp32 when the extension isn't supported, but this change doesn't do that. The coopmat support also requires a glslc that supports the extension, which currently requires a custom build.
1 parent f423981 commit 9cd10f8

File tree

12 files changed

+351
-64
lines changed

12 files changed

+351
-64
lines changed

ggml/src/ggml-vulkan/CMakeLists.txt

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,52 @@ if (Vulkan_FOUND)
6969
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
7070
endif()
7171

72-
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
73-
# If it's not, there will be an error to stderr.
74-
# If it's supported, set a define to indicate that we should compile those shaders
75-
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
76-
OUTPUT_VARIABLE glslc_output
77-
ERROR_VARIABLE glslc_error)
78-
79-
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
80-
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
72+
if(NOT DEFINED GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
73+
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
74+
# If it's not, there will be an error to stderr.
75+
# If it's supported, set a define to indicate that we should compile those shaders
76+
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
77+
OUTPUT_VARIABLE glslc_output
78+
ERROR_VARIABLE glslc_error)
79+
80+
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
81+
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
82+
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether integer dot is supported by glslc")
83+
else()
84+
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
85+
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON CACHE INTERNAL "Whether integer dot is supported by glslc")
86+
endif()
8187
else()
82-
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
88+
message(STATUS "GL_EXT_integer_dot_product support already defined: ${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}")
89+
endif()
90+
91+
if(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
8392
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
8493
endif()
8594

95+
if(NOT DEFINED GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
96+
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
97+
# If it's not, there will be an error to stderr.
98+
# If it's supported, set a define to indicate that we should compile those shaders
99+
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
100+
OUTPUT_VARIABLE glslc_output
101+
ERROR_VARIABLE glslc_error)
102+
103+
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
104+
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
105+
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether bfloat16 is supported by glslc")
106+
else()
107+
message(STATUS "GL_EXT_bfloat16 supported by glslc")
108+
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON CACHE INTERNAL "Whether bfloat16 is supported by glslc")
109+
endif()
110+
else()
111+
message(STATUS "GL_EXT_bfloat16 support already defined: ${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}")
112+
endif()
113+
114+
if(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
115+
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
116+
endif()
117+
86118
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
87119
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
88120

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 163 additions & 21 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ void main() {
1818
// fast path for when all four iterations are in-bounds
1919
if (idx + (num_iter-1)*num_threads < p.ne) {
2020
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
21-
#ifndef OPTIMIZATION_ERROR_WORKAROUND
21+
22+
#if defined(DATA_D_BF16)
23+
float f = float(data_a[get_aoffset() + idx]);
24+
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
25+
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
2226
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
2327
#else
2428
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
@@ -31,7 +35,10 @@ void main() {
3135
continue;
3236
}
3337

34-
#ifndef OPTIMIZATION_ERROR_WORKAROUND
38+
#if defined(DATA_D_BF16)
39+
float f = float(data_a[get_aoffset() + idx]);
40+
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
41+
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
3542
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
3643
#else
3744
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];

ggml/src/ggml-vulkan/vulkan-shaders/copy.comp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ void main() {
1212
return;
1313
}
1414

15-
#ifndef OPTIMIZATION_ERROR_WORKAROUND
15+
#if defined(DATA_D_BF16)
16+
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
17+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
18+
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
1619
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
1720
#else
1821
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2323
}
2424
#endif
2525

26+
#if defined(DATA_A_BF16)
27+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
28+
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
29+
}
30+
#endif
31+
2632
#if defined(DATA_A_Q4_0)
2733
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2834
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
428434
}
429435
#endif
430436

431-
#if defined(DATA_A_F32) || defined(DATA_A_F16)
437+
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
432438
vec2 get_dm(uint ib, uint a_offset) {
433439
return vec2(0, 0);
434440
}

ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,14 @@ void main() {
2020
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
2121
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
2222

23+
#if defined(DATA_A_BF16)
24+
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
25+
#else
26+
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
27+
#endif
2328
#ifndef OPTIMIZATION_ERROR_WORKAROUND
24-
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
29+
data_d[d_offset + i00] = D_TYPE(v);
2530
#else
26-
data_d[d_offset + i00] = data_a[a_offset + i00];
31+
data_d[d_offset + i00] = D_TYPE(v);
2732
#endif
2833
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88

9-
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
9+
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
1010
#define K_PER_ITER 8
1111
#else
1212
#define K_PER_ITER 2

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
1111
#endif
1212

13+
#ifdef DATA_A_BF16
14+
#extension GL_EXT_bfloat16 : enable
15+
#endif
16+
1317
#ifdef COOPMAT
1418
#extension GL_KHR_cooperative_matrix : enable
1519
#extension GL_KHR_memory_scope_semantics : enable
@@ -202,8 +206,8 @@ void main() {
202206
#endif
203207

204208
#ifdef COOPMAT
205-
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
206-
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
209+
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
210+
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
207211
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
208212

209213
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
@@ -248,6 +252,21 @@ void main() {
248252
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
249253
}
250254
#endif
255+
#elif defined(DATA_A_BF16)
256+
#if LOAD_VEC_A == 4
257+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
258+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
259+
buf_a[buf_idx ] = uintBitsToBFloat16EXT(data_a[idx].x);
260+
buf_a[buf_idx + 1] = uintBitsToBFloat16EXT(data_a[idx].y);
261+
buf_a[buf_idx + 2] = uintBitsToBFloat16EXT(data_a[idx].z);
262+
buf_a[buf_idx + 3] = uintBitsToBFloat16EXT(data_a[idx].w);
263+
#else
264+
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
265+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = uintBitsToBFloat16EXT(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
266+
} else {
267+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = uintBitsToBFloat16EXT(uint16_t(0));
268+
}
269+
#endif
251270
#elif defined(DATA_A_Q4_0)
252271
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
253272
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#extension GL_EXT_buffer_reference : enable
1515
#extension GL_KHR_shader_subgroup_ballot : enable
1616
#extension GL_KHR_shader_subgroup_vote : enable
17+
#ifdef DATA_A_BF16
18+
#extension GL_EXT_bfloat16 : enable
19+
#endif
1720

1821
#include "types.comp"
1922

@@ -70,6 +73,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
7073
#define DECODEFUNCA
7174
#endif
7275

76+
#if defined(DATA_A_BF16)
77+
#define MAT_TYPE bfloat16_t
78+
#else
79+
#define MAT_TYPE FLOAT_TYPE
80+
#endif
81+
7382
#ifdef MUL_MAT_ID
7483
layout (binding = 3) readonly buffer IDS {int data_ids[];};
7584

@@ -239,8 +248,8 @@ void main() {
239248
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
240249
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
241250

242-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
243-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
251+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
252+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
244253

245254
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
246255
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -255,8 +264,8 @@ void main() {
255264
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
256265
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
257266

258-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
259-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
267+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
268+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
260269

261270
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
262271
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -271,8 +280,8 @@ void main() {
271280
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
272281
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
273282

274-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
275-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
283+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
284+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
276285

277286
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
278287
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -301,8 +310,8 @@ void main() {
301310
[[dont_unroll]]
302311
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
303312

304-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
305-
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
313+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
314+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
306315

307316
// Clamping is expensive, so detect different code paths for each combination
308317
// of A and B needing clamping.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#version 460
2+
3+
#extension GL_EXT_bfloat16 : require
4+
5+
void main()
6+
{
7+
}

0 commit comments

Comments
 (0)