Skip to content

Commit 97ae596

Browse files
authored
vulkan : support conv_2d_dw with f16 weights (ggml-org#15392)
1 parent 20c2dac commit 97ae596

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ struct vk_device_struct {
530530
vk_pipeline pipeline_opt_step_sgd_f32;
531531
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
532532
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
533-
vk_pipeline pipeline_conv2d_dw_whcn_f32;
534-
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
533+
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
534+
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
535535

536536
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
537537
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
@@ -3257,6 +3257,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
32573257

32583258
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
32593259
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3260+
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3261+
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
32603262

32613263
for (auto &c : compiles) {
32623264
c.wait();
@@ -7346,6 +7348,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
73467348
} else if (ggml_is_contiguous_channels(src1)) {
73477349
return ctx->device->pipeline_conv2d_dw_cwhn_f32;
73487350
}
7351+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
7352+
if (ggml_is_contiguous(src1)) {
7353+
return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
7354+
} else if (ggml_is_contiguous_channels(src1)) {
7355+
return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
7356+
}
73497357
}
73507358
return nullptr;
73517359
default:

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,8 @@ void process_shaders() {
680680

681681
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
682682
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
683+
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
684+
string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
683685

684686
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
685687

0 commit comments

Comments
 (0)