@@ -484,6 +484,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
     vk_pipeline pipeline_conv2d_f32;
+    vk_pipeline pipeline_conv2d_f16_f32;
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
@@ -3074,12 +3075,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
             device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
             sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
             { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
     } else {
         ggml_vk_create_pipeline(
             device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
             sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
             { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
             false);
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
+            false);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
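
Both branches create the f16 pipeline with the same push-constant layout and the same block-size/collectives specialization constants as the existing f32 pipeline; only the shader binary (`conv2d_f16_f32`) changes, reading the kernel as f16. Below is a minimal CPU-side sketch of the arithmetic that variant is expected to perform, under the assumption that each f16 kernel weight is widened to f32 before the multiply-accumulate; the helper is illustrative only and not part of the backend or the shader source:

```cpp
#include "ggml.h"   // ggml_fp16_t, ggml_fp16_to_fp32

// Dot product over the flattened C*R*S axis for one output element.
// Same math as the f32 path; the only difference is the kernel element type.
static float conv2d_dot_f16_f32(const ggml_fp16_t * kernel, const float * patch, int64_t crs) {
    float acc = 0.0f;
    for (int64_t i = 0; i < crs; ++i) {
        acc += ggml_fp16_to_fp32(kernel[i]) * patch[i]; // widen weight to f32, accumulate in f32
    }
    return acc;
}
```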
@@ -6958,9 +6968,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_CONV_2D:
-        if (src0->type == GGML_TYPE_F32 &&  src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
             ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            return ctx->device->pipeline_conv2d_f32;
+            if (src0->type == GGML_TYPE_F32) {
+                return ctx->device->pipeline_conv2d_f32;
+            } else if (src0->type == GGML_TYPE_F16) {
+                return ctx->device->pipeline_conv2d_f16_f32;
+            }
         }
         return nullptr;
     case GGML_OP_CONV_2D_DW:
@@ -7882,6 +7896,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
+    // Skip empty set_rows operations. For most ops the empty check at the start
+    // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
+    // with empty srcs.
+    if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
+        return;
+    }
+
     ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
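
The guard relies on `ggml_is_empty()`, which treats a tensor as empty when any dimension has zero extent. A toy sketch of that property, with a hypothetical `toy_tensor` type standing in for `ggml_tensor`, showing how a SET_ROWS node can have a nonempty destination while its sources are empty:

```cpp
#include <cstdint>

// Hypothetical stand-in for ggml_tensor; in ggml the check is ggml_is_empty().
struct toy_tensor { int64_t ne[4]; }; // extents per dimension, as in ggml_tensor::ne

static bool toy_is_empty(const toy_tensor & t) {
    for (int i = 0; i < 4; ++i) {
        if (t.ne[i] == 0) {
            return true; // zero extent in any dimension -> zero elements overall
        }
    }
    return false;
}

// dst keeps its rows, but there are no rows to write: dst is nonempty while the
// sources are empty, so a graph-level empty check on dst alone would not skip the dispatch.
static const toy_tensor dst = { { 4096, 32, 1, 1 } };
static const toy_tensor src = { { 4096,  0, 1, 1 } };
```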
@@ -8178,13 +8199,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
 
 static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
                             const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
     GGML_ASSERT(nb0 == sizeof(float));
 
@@ -10867,7 +10888,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 const vk_device& device = ggml_vk_get_device(ctx->device);
                 bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
                 // Channel-contiguous format is not supported yet.
-                return (op->src[0]->type == GGML_TYPE_F32 &&
+                return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                     op->src[1]->type == GGML_TYPE_F32 &&
                     op->type == GGML_TYPE_F32 &&
                     ggml_is_contiguous(op->src[0]) &&
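
Together with the contiguity checks that follow, the support predicate now accepts an f16 or f32 kernel with f32 activations and f32 output. A condensed restatement of the type rule, mirroring the return expression above rather than introducing any new backend API:

```cpp
#include "ggml.h"   // ggml_type, GGML_TYPE_F32, GGML_TYPE_F16

// Restates the CONV_2D type condition after this change (contiguity and the
// channel-contiguous-layout exclusion are checked separately in supports_op).
static bool vk_conv2d_types_supported(ggml_type kernel, ggml_type input, ggml_type output) {
    const bool kernel_ok = kernel == GGML_TYPE_F32 || kernel == GGML_TYPE_F16;
    return kernel_ok && input == GGML_TYPE_F32 && output == GGML_TYPE_F32;
}
```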