@@ -484,6 +484,7 @@ struct vk_device_struct {
484484 vk_pipeline pipeline_rwkv_wkv7_f32;
485485 vk_pipeline pipeline_opt_step_adamw_f32;
486486 vk_pipeline pipeline_conv2d_f32;
487+ vk_pipeline pipeline_conv2d_f16_f32;
487488 vk_pipeline pipeline_conv2d_dw_whcn_f32;
488489 vk_pipeline pipeline_conv2d_dw_cwhn_f32;
489490
@@ -3074,12 +3075,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
30743075 device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
30753076 sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
30763077 { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
3078+ ggml_vk_create_pipeline(
3079+ device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3080+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3081+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
30773082 } else {
30783083 ggml_vk_create_pipeline(
30793084 device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
30803085 sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
30813086 { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
30823087 false);
3088+ ggml_vk_create_pipeline(
3089+ device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3090+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3091+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
3092+ false);
30833093 }
30843094
30853095 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -6958,9 +6968,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
69586968 }
69596969 return nullptr;
69606970 case GGML_OP_CONV_2D:
6961- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
6971+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
69626972 ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
6963- return ctx->device->pipeline_conv2d_f32;
6973+ if (src0->type == GGML_TYPE_F32) {
6974+ return ctx->device->pipeline_conv2d_f32;
6975+ } else if (src0->type == GGML_TYPE_F16) {
6976+ return ctx->device->pipeline_conv2d_f16_f32;
6977+ }
69646978 }
69656979 return nullptr;
69666980 case GGML_OP_CONV_2D_DW:
@@ -8178,13 +8192,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
81788192
81798193static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
81808194 const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8181- GGML_ASSERT(src0->type == GGML_TYPE_F32);
8195+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
81828196 GGML_ASSERT(src1->type == GGML_TYPE_F32);
81838197 GGML_ASSERT(dst->type == GGML_TYPE_F32);
81848198
81858199 GGML_TENSOR_BINARY_OP_LOCALS
81868200
8187- GGML_ASSERT(nb00 == sizeof(float));
8201+ GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
81888202 GGML_ASSERT(nb10 == sizeof(float));
81898203 GGML_ASSERT(nb0 == sizeof(float));
81908204
0 commit comments