@@ -368,6 +368,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_conv2d_dw_whcn_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -680,6 +682,24 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_conv2d_dw_push_constants {
+    uint32_t ne;
+    uint32_t batches;
+    uint32_t channels;
+    uint32_t dst_w;
+    uint32_t dst_h;
+    uint32_t src_w;
+    uint32_t src_h;
+    uint32_t knl_w;
+    uint32_t knl_h;
+    int32_t stride_x;
+    int32_t stride_y;
+    int32_t pad_x;
+    int32_t pad_y;
+    int32_t dilation_x;
+    int32_t dilation_y;
+};
+
 struct vk_op_upscale_push_constants {
     uint32_t ne; uint32_t a_offset; uint32_t d_offset;
     uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
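The same 15 scalars are read by the compute shaders as a push-constant block, so the struct must stay tightly packed: nine uint32_t followed by six int32_t, all 4-byte aligned, leaving no room for implicit padding. A compile-time guard along these lines (a sketch, not part of the patch) would catch accidental reordering or padding:

    // Illustrative sanity check: 15 tightly packed 4-byte scalars, 60 bytes total,
    // matching the push-constant size the conv2d_dw pipelines are created with.
    static_assert(sizeof(vk_op_conv2d_dw_push_constants) == 15 * sizeof(uint32_t),
                  "vk_op_conv2d_dw_push_constants must remain tightly packed");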
@@ -2529,6 +2549,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
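Both pipelines are registered with three buffer bindings (kernel, input, destination), the push-constant struct defined above, and a 512-wide one-dimensional workgroup. Assuming one output element per shader invocation (an assumption about the shaders, which are not part of this diff), the required workgroup count is plain ceiling division:

    // Hypothetical sizing example: a 64x64 output with 32 channels and batch size 4.
    const uint32_t ne     = 64 * 64 * 32 * 4;        // 524288 output elements
    const uint32_t groups = (ne + 512 - 1) / 512;    // 1024 workgroups of 512 invocations

The real dispatch dimensions are computed centrally in ggml_vk_op_f32, which splits large element counts across extra dispatch dimensions (the 262144 threshold visible in a later hunk).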
@@ -5988,6 +6011,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_leaky_relu_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_2D_DW:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            if (ggml_is_contiguous(src1)) {
+                return ctx->device->pipeline_conv2d_dw_whcn_f32;
+            } else if (ggml_is_contiguous_channels(src1)) {
+                return ctx->device->pipeline_conv2d_dw_cwhn_f32;
+            }
+        }
+        return nullptr;
     default:
         return nullptr;
     }
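The two pipelines differ only in the memory order assumed for the input: a fully contiguous src1 is in ggml's default WHCN layout (width fastest in memory), while ggml_is_contiguous_channels detects a channels-last view in which the channel dimension has the smallest stride. As a rough sketch of how such a CWHN input can be produced with standard ggml view ops (the helper name is made up, illustrative only):

    // Hypothetical helper: build a logical (W,H,C,N) tensor whose memory is
    // channels-last, so the CWHN pipeline above gets selected.
    static ggml_tensor * to_channels_last(ggml_context * ctx, ggml_tensor * whcn) {
        // reorder the data to (C,W,H,N), then permute the view back to (W,H,C,N);
        // the bytes now iterate channels fastest while the logical shape is unchanged
        ggml_tensor * cwhn = ggml_cont(ctx, ggml_permute(ctx, whcn, 1, 2, 0, 3));
        return ggml_permute(ctx, cwhn, 2, 0, 1, 3);
    }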
@@ -6014,6 +6046,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_CONV_2D_DW:
         return true;
     default:
         return false;
@@ -6310,6 +6343,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_UNARY:
+    case GGML_OP_CONV_2D_DW:
         {
             const uint32_t ne = ggml_nelements(dst);
             if (ne > 262144) {
@@ -7096,6 +7130,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    vk_op_conv2d_dw_push_constants p{};
+    p.ne = ggml_nelements(dst);
+    p.channels = dst->ne[2];
+    p.batches = dst->ne[3];
+    p.dst_w = dst->ne[0];
+    p.dst_h = dst->ne[1];
+    p.src_w = src1->ne[0];
+    p.src_h = src1->ne[1];
+    p.knl_w = src0->ne[0];
+    p.knl_h = src0->ne[1];
+    p.stride_x = dst->op_params[0];
+    p.stride_y = dst->op_params[1];
+    p.pad_x = dst->op_params[2];
+    p.pad_y = dst->op_params[3];
+    p.dilation_x = dst->op_params[4];
+    p.dilation_y = dst->op_params[5];
+
+    GGML_ASSERT(src0->ne[3] == p.channels);
+    GGML_ASSERT(src1->ne[3] == p.batches);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
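The push constants describe a plain depthwise convolution: each output element combines a knl_w x knl_h window of one channel of src1 with that channel's plane of src0 (the asserts require one kernel plane per channel and one input image per batch). A scalar reference of the per-element computation, written only as a sketch for the contiguous WHCN layout and assuming src0->ne[2] == 1, looks like this:

    // Illustrative reference (not the shader code): dst(x, y, c, n) for WHCN data.
    static float conv2d_dw_ref(const float * src, const float * knl,
                               const vk_op_conv2d_dw_push_constants & p,
                               uint32_t x, uint32_t y, uint32_t c, uint32_t n) {
        float sum = 0.0f;
        for (uint32_t ky = 0; ky < p.knl_h; ++ky) {
            for (uint32_t kx = 0; kx < p.knl_w; ++kx) {
                const int32_t sx = (int32_t) (x * p.stride_x + kx * p.dilation_x) - p.pad_x;
                const int32_t sy = (int32_t) (y * p.stride_y + ky * p.dilation_y) - p.pad_y;
                if (sx < 0 || sx >= (int32_t) p.src_w || sy < 0 || sy >= (int32_t) p.src_h) {
                    continue; // zero padding outside the input
                }
                sum += src[((n * p.channels + c) * p.src_h + sy) * p.src_w + sx]
                     * knl[(c * p.knl_h + ky) * p.knl_w + kx];
            }
        }
        return sum;
    }

The index convention sx = x*stride + kx*dilation - pad matches how ggml's im2col-based conv paths interpret the same op_params.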
@@ -8116,6 +8174,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -8179,6 +8238,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
@@ -8352,6 +8412,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_2D_DW:
+        ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_LEAKY_RELU:
         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -8473,6 +8537,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -9442,6 +9507,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_COUNT_EQUAL:
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_CONV_2D_DW:
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7: