
Commit 60b5d31

vulkan: Support bf16 tensors without the bf16 extension or coopmat support
Compile a variant of the scalar mul_mm shader that promotes the bf16 values to float, and use it when either the bf16 extension or coopmat support is not available.
1 parent 5f1bb0f commit 60b5d31
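For context on what "promote the bf16 values to float" means here: bf16 is the upper 16 bits of an IEEE-754 binary32, so the promotion is just widening the raw 16-bit pattern into the high half of a 32-bit float. Below is a minimal host-side C++ sketch of that conversion, illustrative only; the shader-side helper referenced by the generator change further down is named bf16_to_fp32, but this snippet is not the actual shader code.

#include <cstdint>
#include <cstring>
#include <cstdio>

// bf16 keeps the sign, exponent, and top 7 mantissa bits of an fp32, so shifting
// the raw 16-bit pattern into the upper half of a uint32_t yields a valid float.
static float bf16_bits_to_float(uint16_t bits) {
    uint32_t widened = static_cast<uint32_t>(bits) << 16;
    float out;
    std::memcpy(&out, &widened, sizeof(out));   // bit-cast without undefined behavior
    return out;
}

int main() {
    // 0x3F80 is the bf16 pattern for 1.0f; 0xC048 is the pattern for -3.125f.
    std::printf("%f %f\n", bf16_bits_to_float(0x3F80), bf16_bits_to_float(0xC048));
    return 0;
}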

3 files changed: +52 / -35 lines changed


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 27 additions & 21 deletions
@@ -1924,8 +1924,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
     if (device->coopmat_bf16_support) {
         CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-    }
+    } else
 #endif
+    {
+        CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
+    }
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1950,8 +1953,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
     if (device->coopmat_bf16_support) {
         CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-    }
+    } else
 #endif
+    {
+        CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4);
+    }
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -2008,8 +2014,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
     if (device->coopmat_bf16_support) {
         CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
-    }
+    } else
 #endif
+    {
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+    }
 
     if (device->coopmat_acc_f16_support) {
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2061,8 +2070,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
     if (device->coopmat_bf16_support) {
         CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
-    }
+    } else
 #endif
+    {
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+    }
 
     if (device->coopmat_acc_f16_support) {
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2145,6 +2157,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2180,6 +2194,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2232,6 +2248,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2267,6 +2285,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -5119,11 +5139,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
         { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
-static bool ggml_vk_can_use_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * dst) {
-    return (dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
-           (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type));
-}
-
 static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
     if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
@@ -5142,7 +5157,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
         ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
-    } else if (ggml_vk_can_use_mul_mat_vec(src0, src1, dst)) {
+    } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
+               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
         ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -9344,6 +9360,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             switch (src0_type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
@@ -9364,17 +9381,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_TYPE_IQ4_XS:
                 case GGML_TYPE_IQ4_NL:
                     break;
-                case GGML_TYPE_BF16:
-                    if (!device->coopmat_bf16_support) {
-                        if (op->op == GGML_OP_MUL_MAT_ID) {
-                            return false;
-                        }
-                        // mul_mat_vec expands to float and doesn't require bf16 hardware support
-                        if (!ggml_vk_can_use_mul_mat_vec(op->src[0], op->src[1], op)) {
-                            return false;
-                        }
-                    }
-                    break;
                 default:
                     return false;
             }
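A note on the `} else` placed just before `#endif` in the pipeline-creation hunks above: when GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT is defined, the two CREATE_MM calls form an if/else on device->coopmat_bf16_support; when it is not defined, only the second block survives preprocessing, so the float-promoting fallback is always used. A stripped-down C++ sketch of that pattern follows; the two functions are illustrative stand-ins for the CREATE_MM invocations, not real ggml functions.

#include <cstdio>

// Stand-ins for the two CREATE_MM invocations above (illustrative only).
static void create_bf16_pipeline_coopmat()  { std::puts("bf16 coopmat mul_mm"); }
static void create_bf16_pipeline_fallback() { std::puts("scalar mul_mm, bf16 promoted to float"); }

static void load_bf16_matmul(bool coopmat_bf16_support) {
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
    // bf16 GLSL support compiled in: choose per device at runtime.
    if (coopmat_bf16_support) {
        create_bf16_pipeline_coopmat();
    } else
#endif
    {
        // Without the define, only this block remains after preprocessing.
        create_bf16_pipeline_fallback();
    }
}

int main() {
    load_bf16_matmul(/*coopmat_bf16_support=*/false);
    return 0;
}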

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 16 additions & 12 deletions
@@ -33,6 +33,10 @@
 #define LOAD_VEC_B 1
 #endif
 
+#if !defined(TO_FLOAT_TYPE)
+#define TO_FLOAT_TYPE FLOAT_TYPE
+#endif
+
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -256,15 +260,15 @@ void main() {
 #if LOAD_VEC_A == 4
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
         const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-        buf_a[buf_idx    ] = uintBitsToBFloat16EXT(data_a[idx].x);
-        buf_a[buf_idx + 1] = uintBitsToBFloat16EXT(data_a[idx].y);
-        buf_a[buf_idx + 2] = uintBitsToBFloat16EXT(data_a[idx].z);
-        buf_a[buf_idx + 3] = uintBitsToBFloat16EXT(data_a[idx].w);
+        buf_a[buf_idx    ] = TO_FLOAT_TYPE(data_a[idx].x);
+        buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
+        buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
+        buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
 #else
         if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-            buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = uintBitsToBFloat16EXT(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+            buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
         } else {
-            buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = uintBitsToBFloat16EXT(uint16_t(0));
+            buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
         }
 #endif
 #elif defined(DATA_A_Q4_0)
@@ -714,21 +718,21 @@ void main() {
         const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
 #endif
         const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
-        buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
-        buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
-        buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
-        buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
+        buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
+        buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
+        buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
+        buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
 #elif !MUL_MAT_ID
         if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
-            buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
+            buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
         } else {
             buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
         }
 #else
         const uint row_i = ic * BN + loadc_b + l;
         if (row_i < _ne1) {
             const u16vec2 row_idx = row_ids[row_i];
-            buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
+            buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
         } else {
             buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
         }

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 9 additions & 2 deletions
@@ -321,6 +321,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
     auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
         if (t == "bf16") {
+            // scalar path promotes to float
+            if (!coopmat && !coopmat2) {
+                return "float";
+            }
             return "bfloat16_t";
         }
         if (coopmat2 || fp16) {
@@ -343,8 +347,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         // For aligned matmul loads
         std::string load_vec_a = coopmat2 ? "1" : "4";
 
-        string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-        string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "bfloat16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        // scalar path promotes to float
+        std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
+
+        string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "uint16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
     }
 #endif
 
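Putting the generator and shader changes together: for the scalar (no coopmat, no coopmat2) build, the bf16 mul_mm variants are compiled with float accumulation and raw 16-bit inputs. The snippet below is only a hedged summary of the defines visible in the diff above, written as standalone C++; base_dict entries and the LOAD_VEC_A/LOAD_VEC_B values come from existing generator state and are omitted, and this is not the generator's actual data structure.

#include <cstdio>
#include <map>
#include <string>

int main() {
    // Defines passed for the scalar bf16 fallback variant, per the diff above.
    std::map<std::string, std::string> scalar_bf16_defines = {
        {"DATA_A_BF16",   "1"},            // A holds raw bf16 bit patterns (uint16 storage)
        {"FLOAT_TYPE",    "float"},        // shared-memory / accumulator element type
        {"TO_FLOAT_TYPE", "bf16_to_fp32"}, // promotion helper applied on loads
        {"B_TYPE",        "uint16_t"},     // u16vec4 for the aligned variant
        {"D_TYPE",        "float"},
        {"B_IS_FLOAT",    "1"},
    };
    for (const auto &kv : scalar_bf16_defines) {
        std::printf("%s=%s\n", kv.first.c_str(), kv.second.c_str());
    }
    return 0;
}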
