@@ -389,6 +389,8 @@ struct vk_device_struct {
389389 vk_pipeline pipeline_rwkv_wkv6_f32;
390390 vk_pipeline pipeline_rwkv_wkv7_f32;
391391 vk_pipeline pipeline_opt_step_adamw_f32;
392+ vk_pipeline pipeline_conv2d_dw_whcn_f32;
393+ vk_pipeline pipeline_conv2d_dw_cwhn_f32;
392394
393395 // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
394396 vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -701,6 +703,24 @@ struct vk_op_rwkv_wkv7_push_constants {
701703 uint32_t H;
702704};
703705
706+ struct vk_op_conv2d_dw_push_constants {
707+ uint32_t ne;
708+ uint32_t batches;
709+ uint32_t channels;
710+ uint32_t dst_w;
711+ uint32_t dst_h;
712+ uint32_t src_w;
713+ uint32_t src_h;
714+ uint32_t knl_w;
715+ uint32_t knl_h;
716+ int32_t stride_x;
717+ int32_t stride_y;
718+ int32_t pad_x;
719+ int32_t pad_y;
720+ int32_t dilation_x;
721+ int32_t dilation_y;
722+ };
723+
704724struct vk_op_upscale_push_constants {
705725 uint32_t ne; uint32_t a_offset; uint32_t d_offset;
706726 uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2610,6 +2630,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
26102630
26112631 ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
26122632
2633+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2634+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2635+
26132636 for (auto &c : compiles) {
26142637 c.wait();
26152638 }
@@ -6137,6 +6160,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
61376160 return ctx->device->pipeline_leaky_relu_f32;
61386161 }
61396162 return nullptr;
6163+ case GGML_OP_CONV_2D_DW:
6164+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6165+ if (ggml_is_contiguous(src1)) {
6166+ return ctx->device->pipeline_conv2d_dw_whcn_f32;
6167+ } else if (ggml_is_contiguous_channels(src1)) {
6168+ return ctx->device->pipeline_conv2d_dw_cwhn_f32;
6169+ }
6170+ }
6171+ return nullptr;
61406172 default:
61416173 return nullptr;
61426174 }
@@ -6163,6 +6195,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
61636195 case GGML_OP_REPEAT_BACK:
61646196 case GGML_OP_ROPE:
61656197 case GGML_OP_RMS_NORM:
6198+ case GGML_OP_CONV_2D_DW:
61666199 return true;
61676200 default:
61686201 return false;
@@ -6459,6 +6492,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
64596492 case GGML_OP_CONCAT:
64606493 case GGML_OP_UPSCALE:
64616494 case GGML_OP_UNARY:
6495+ case GGML_OP_CONV_2D_DW:
64626496 {
64636497 const uint32_t ne = ggml_nelements(dst);
64646498 if (ne > 262144) {
@@ -7245,6 +7279,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
72457279 }, dryrun);
72467280}
72477281
7282+ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7283+ vk_op_conv2d_dw_push_constants p{};
7284+ p.ne = ggml_nelements(dst);
7285+ p.channels = dst->ne[2];
7286+ p.batches = dst->ne[3];
7287+ p.dst_w = dst->ne[0];
7288+ p.dst_h = dst->ne[1];
7289+ p.src_w = src1->ne[0];
7290+ p.src_h = src1->ne[1];
7291+ p.knl_w = src0->ne[0];
7292+ p.knl_h = src0->ne[1];
7293+ p.stride_x = dst->op_params[0];
7294+ p.stride_y = dst->op_params[1];
7295+ p.pad_x = dst->op_params[2];
7296+ p.pad_y = dst->op_params[3];
7297+ p.dilation_x = dst->op_params[4];
7298+ p.dilation_y = dst->op_params[5];
7299+
7300+ GGML_ASSERT(src0->ne[3] == p.channels);
7301+ GGML_ASSERT(src1->ne[3] == p.batches);
7302+
7303+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
7304+ }
7305+
72487306static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
72497307 const float * op_params = (const float *)dst->op_params;
72507308 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
@@ -8265,6 +8323,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
82658323 case GGML_OP_IM2COL:
82668324 case GGML_OP_TIMESTEP_EMBEDDING:
82678325 case GGML_OP_POOL_2D:
8326+ case GGML_OP_CONV_2D_DW:
82688327 case GGML_OP_RWKV_WKV6:
82698328 case GGML_OP_RWKV_WKV7:
82708329 case GGML_OP_LEAKY_RELU:
@@ -8328,6 +8387,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
83288387 case GGML_OP_IM2COL:
83298388 case GGML_OP_TIMESTEP_EMBEDDING:
83308389 case GGML_OP_POOL_2D:
8390+ case GGML_OP_CONV_2D_DW:
83318391 case GGML_OP_LEAKY_RELU:
83328392 {
83338393 // These operations all go through ggml_vk_op_f32, so short-circuit and
@@ -8501,6 +8561,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
85018561 case GGML_OP_POOL_2D:
85028562 ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
85038563
8564+ break;
8565+ case GGML_OP_CONV_2D_DW:
8566+ ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
8567+
85048568 break;
85058569 case GGML_OP_LEAKY_RELU:
85068570 ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -8622,6 +8686,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
86228686 case GGML_OP_IM2COL:
86238687 case GGML_OP_TIMESTEP_EMBEDDING:
86248688 case GGML_OP_POOL_2D:
8689+ case GGML_OP_CONV_2D_DW:
86258690 case GGML_OP_RWKV_WKV6:
86268691 case GGML_OP_RWKV_WKV7:
86278692 case GGML_OP_LEAKY_RELU:
@@ -9599,6 +9664,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
95999664 case GGML_OP_COUNT_EQUAL:
96009665 case GGML_OP_IM2COL:
96019666 case GGML_OP_TIMESTEP_EMBEDDING:
9667+ case GGML_OP_CONV_2D_DW:
96029668 case GGML_OP_POOL_2D:
96039669 case GGML_OP_RWKV_WKV6:
96049670 case GGML_OP_RWKV_WKV7:
0 commit comments