Skip to content

Commit 21e8793

Browse files
committed
vulkan: Support bf16 tensors without the bf16 extension or coopmat support
Compile a variant of the scalar mul_mm shader that promotes bf16 values to float, and use it whenever either the bf16 extension or the coopmat extensions is unavailable.
1 parent 9cd10f8 commit 21e8793

File tree

3 files changed

+52
-35
lines changed

3 files changed

+52
-35
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,8 +1870,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
18701870
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
18711871
if (device->coopmat_bf16_support) {
18721872
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1873-
}
1873+
} else
18741874
#endif
1875+
{
1876+
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1877+
}
18751878
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
18761879
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
18771880
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1896,8 +1899,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
18961899
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
18971900
if (device->coopmat_bf16_support) {
18981901
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1899-
}
1902+
} else
19001903
#endif
1904+
{
1905+
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4);
1906+
}
19011907
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19021908
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
19031909
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1954,8 +1960,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
19541960
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19551961
if (device->coopmat_bf16_support) {
19561962
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
1957-
}
1963+
} else
19581964
#endif
1965+
{
1966+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1967+
}
19591968

19601969
if (device->coopmat_acc_f16_support) {
19611970
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2007,8 +2016,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
20072016
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20082017
if (device->coopmat_bf16_support) {
20092018
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2010-
}
2019+
} else
20112020
#endif
2021+
{
2022+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2023+
}
20122024

20132025
if (device->coopmat_acc_f16_support) {
20142026
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2091,6 +2103,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20912103
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
20922104
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
20932105

2106+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2107+
20942108
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
20952109
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
20962110
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2126,6 +2140,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
21262140
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
21272141
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
21282142

2143+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2144+
21292145
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
21302146
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
21312147
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2178,6 +2194,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
21782194
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
21792195
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
21802196

2197+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2198+
21812199
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
21822200
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
21832201
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2213,6 +2231,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
22132231
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
22142232
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
22152233

2234+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2235+
22162236
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
22172237
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
22182238
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -5057,11 +5077,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
50575077
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
50585078
}
50595079

5060-
static bool ggml_vk_can_use_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * dst) {
5061-
return (dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
5062-
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type));
5063-
}
5064-
50655080
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
50665081
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
50675082
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
@@ -5080,7 +5095,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
50805095
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
50815096
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
50825097
// when ne12 and ne13 are one.
5083-
} else if (ggml_vk_can_use_mul_mat_vec(src0, src1, dst)) {
5098+
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
5099+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
50845100
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
50855101
} else {
50865102
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -9187,6 +9203,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
91879203
switch (src0_type) {
91889204
case GGML_TYPE_F32:
91899205
case GGML_TYPE_F16:
9206+
case GGML_TYPE_BF16:
91909207
case GGML_TYPE_Q4_0:
91919208
case GGML_TYPE_Q4_1:
91929209
case GGML_TYPE_Q5_0:
@@ -9207,17 +9224,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
92079224
case GGML_TYPE_IQ4_XS:
92089225
case GGML_TYPE_IQ4_NL:
92099226
break;
9210-
case GGML_TYPE_BF16:
9211-
if (!device->coopmat_bf16_support) {
9212-
if (op->op == GGML_OP_MUL_MAT_ID) {
9213-
return false;
9214-
}
9215-
// mul_mat_vec expands to float and doesn't require bf16 hardware support
9216-
if (!ggml_vk_can_use_mul_mat_vec(op->src[0], op->src[1], op)) {
9217-
return false;
9218-
}
9219-
}
9220-
break;
92219227
default:
92229228
return false;
92239229
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
#define LOAD_VEC_B 1
3434
#endif
3535

36+
#if !defined(TO_FLOAT_TYPE)
37+
#define TO_FLOAT_TYPE FLOAT_TYPE
38+
#endif
39+
3640
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3741

3842
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -256,15 +260,15 @@ void main() {
256260
#if LOAD_VEC_A == 4
257261
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
258262
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
259-
buf_a[buf_idx ] = uintBitsToBFloat16EXT(data_a[idx].x);
260-
buf_a[buf_idx + 1] = uintBitsToBFloat16EXT(data_a[idx].y);
261-
buf_a[buf_idx + 2] = uintBitsToBFloat16EXT(data_a[idx].z);
262-
buf_a[buf_idx + 3] = uintBitsToBFloat16EXT(data_a[idx].w);
263+
buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
264+
buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
265+
buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
266+
buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
263267
#else
264268
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
265-
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = uintBitsToBFloat16EXT(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
269+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
266270
} else {
267-
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = uintBitsToBFloat16EXT(uint16_t(0));
271+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
268272
}
269273
#endif
270274
#elif defined(DATA_A_Q4_0)
@@ -714,21 +718,21 @@ void main() {
714718
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
715719
#endif
716720
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
717-
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
718-
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
719-
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
720-
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
721+
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
722+
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
723+
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
724+
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
721725
#elif !MUL_MAT_ID
722726
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
723-
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
727+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
724728
} else {
725729
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
726730
}
727731
#else
728732
const uint row_i = ic * BN + loadc_b + l;
729733
if (row_i < _ne1) {
730734
const u16vec2 row_idx = row_ids[row_i];
731-
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
735+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
732736
} else {
733737
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
734738
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
321321

322322
auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
323323
if (t == "bf16") {
324+
// scalar path promotes to float
325+
if (!coopmat && !coopmat2) {
326+
return "float";
327+
}
324328
return "bfloat16_t";
325329
}
326330
if (coopmat2 || fp16) {
@@ -343,8 +347,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
343347
// For aligned matmul loads
344348
std::string load_vec_a = coopmat2 ? "1" : "4";
345349

346-
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
347-
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "bfloat16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
350+
// scalar path promotes to float
351+
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
352+
353+
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "uint16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
354+
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
348355
}
349356
#endif
350357

0 commit comments

Comments
 (0)