
Commit 0482de9

vulkan : kernels for depthwise 2D convolution (CONV_2D_DW) (#1204)
* vulkan : add kernels for depthwise 2d convolution (OP_CONV_2D_DW)
* review: remove src_x/y < 0 checks; add performance tests
1 parent f3a375f commit 0482de9

File tree

4 files changed: +223 -0 lines changed

src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 66 additions & 0 deletions
@@ -368,6 +368,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_conv2d_dw_whcn_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -680,6 +682,24 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_conv2d_dw_push_constants {
+    uint32_t ne;
+    uint32_t batches;
+    uint32_t channels;
+    uint32_t dst_w;
+    uint32_t dst_h;
+    uint32_t src_w;
+    uint32_t src_h;
+    uint32_t knl_w;
+    uint32_t knl_h;
+    int32_t stride_x;
+    int32_t stride_y;
+    int32_t pad_x;
+    int32_t pad_y;
+    int32_t dilation_x;
+    int32_t dilation_y;
+};
+
 struct vk_op_upscale_push_constants {
     uint32_t ne; uint32_t a_offset; uint32_t d_offset;
     uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
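One thing worth noting: this struct has to mirror the `push_constant` block in the new shader member-for-member. A hedged sanity check (illustrative only, not part of the commit):

    // 9 x uint32_t + 6 x int32_t, 4 bytes each and no padding -> 60 bytes,
    // matching the packed size of the GLSL push_constant block in conv2d_dw.comp.
    static_assert(sizeof(vk_op_conv2d_dw_push_constants) == 15 * sizeof(uint32_t),
                  "push constants must match the conv2d_dw.comp layout");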
@@ -2529,6 +2549,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -5988,6 +6011,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_leaky_relu_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_2D_DW:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            if (ggml_is_contiguous(src1)) {
+                return ctx->device->pipeline_conv2d_dw_whcn_f32;
+            } else if (ggml_is_contiguous_channels(src1)) {
+                return ctx->device->pipeline_conv2d_dw_cwhn_f32;
+            }
+        }
+        return nullptr;
     default:
         return nullptr;
     }
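The pipeline selection keys on the source layout: `ggml_is_contiguous` means the default WHCN order, `ggml_is_contiguous_channels` the channels-innermost CWHN order. As a rough sketch of the addressing each pipeline assumes (helper names are mine, mirroring the shader's index math):

    // WHCN: x varies fastest, matching conv_2d_dw_whcn() in the shader.
    size_t idx_whcn(size_t x, size_t y, size_t c, size_t n,
                    size_t W, size_t H, size_t C) {
        return ((n * C + c) * H + y) * W + x;
    }

    // CWHN: the channel varies fastest, matching conv_2d_dw_cwhn().
    size_t idx_cwhn(size_t x, size_t y, size_t c, size_t n,
                    size_t W, size_t H, size_t C) {
        return ((n * H + y) * W + x) * C + c;
    }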
@@ -6014,6 +6046,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_CONV_2D_DW:
         return true;
     default:
         return false;
@@ -6310,6 +6343,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_UNARY:
+    case GGML_OP_CONV_2D_DW:
         {
             const uint32_t ne = ggml_nelements(dst);
             if (ne > 262144) {
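The `ne > 262144` branch ties into how the shader recovers a flat element index from its 3D invocation ID. A sketch of the mapping this implies (inferred from the shader's `idx` computation, not copied from the backend):

    // 512 threads per workgroup in x; 262144 == 512 * 512.
    std::array<uint32_t, 3> elements;
    if (ne > 262144) {
        elements = { 512, 512, (ne + 262143) / 262144 };  // fold into y and z
    } else if (ne > 512) {
        elements = { 512, (ne + 511) / 512, 1 };
    } else {
        elements = { ne, 1, 1 };
    }
    // conv2d_dw.comp inverts this: idx = z*262144 + y*512 + x, early-out at idx >= ne.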
@@ -7096,6 +7130,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    vk_op_conv2d_dw_push_constants p{};
+    p.ne = ggml_nelements(dst);
+    p.channels = dst->ne[2];
+    p.batches = dst->ne[3];
+    p.dst_w = dst->ne[0];
+    p.dst_h = dst->ne[1];
+    p.src_w = src1->ne[0];
+    p.src_h = src1->ne[1];
+    p.knl_w = src0->ne[0];
+    p.knl_h = src0->ne[1];
+    p.stride_x = dst->op_params[0];
+    p.stride_y = dst->op_params[1];
+    p.pad_x = dst->op_params[2];
+    p.pad_y = dst->op_params[3];
+    p.dilation_x = dst->op_params[4];
+    p.dilation_y = dst->op_params[5];
+
+    GGML_ASSERT(src0->ne[3] == p.channels);
+    GGML_ASSERT(src1->ne[3] == p.batches);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
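Note that `dst_w`/`dst_h` are taken directly from `dst`, which `ggml_conv_2d_dw_direct` already sized when the graph was built. For reference, the standard relation between the extents (my restatement, not code from this diff):

    // out = (in + 2*pad - dilation*(knl - 1) - 1) / stride + 1
    // e.g. the eval test below with src 32x8, knl 3x3, stride 2, pad 1:
    //   dst_w = (32 + 2 - 2 - 1)/2 + 1 = 16;  dst_h = (8 + 2 - 2 - 1)/2 + 1 = 4
    int64_t conv_2d_dw_out_size(int64_t in, int64_t knl, int stride, int pad, int dilation) {
        return (in + 2 * pad - dilation * (knl - 1) - 1) / stride + 1;
    }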
@@ -8116,6 +8174,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -8179,6 +8238,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
@@ -8352,6 +8412,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_2D_DW:
+        ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_LEAKY_RELU:
         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -8473,6 +8537,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -9442,6 +9507,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+#version 450
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+    uint ne;
+    uint batches;
+    uint channels;
+    uint dst_w;
+    uint dst_h;
+    uint src_w;
+    uint src_h;
+    uint knl_w;
+    uint knl_h;
+    int stride_x;
+    int stride_y;
+    int pad_x;
+    int pad_y;
+    int dilation_x;
+    int dilation_y;
+} p;
+
+layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
+layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
+    uint i0 = idx / p.dst_w;
+    uint dst_x = idx - i0 * p.dst_w;
+    uint i1 = i0 / p.dst_h;
+    uint dst_y = i0 - i1 * p.dst_h;
+    uint n = i1 / p.channels;
+    uint c = i1 - n * p.channels;
+
+    uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
+    uint knl_i = c * p.knl_h * p.knl_w;
+
+    FLOAT_TYPE sum = 0.0;
+    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+            continue;
+        }
+        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+                continue;
+            }
+            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
+            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
+            sum = fma(v, k, sum);
+        }
+    }
+    return sum;
+}
+
+FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
+    uint i0 = idx / p.channels;
+    uint c = idx - i0 * p.channels;
+    uint i1 = i0 / p.dst_w;
+    uint dst_x = i0 - i1 * p.dst_w;
+    uint n = i1 / p.dst_h;
+    uint dst_y = i1 - n * p.dst_h;
+
+    uint src_i = n * p.channels * p.src_h * p.src_w;
+    uint src_row = p.src_w * p.channels;
+    uint knl_row = p.knl_w * p.channels;
+
+    FLOAT_TYPE sum = 0.0;
+    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+        if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
+            continue;
+        }
+        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+            if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
+                continue;
+            }
+            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
+            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_y * knl_row + knl_x * p.channels + c]);
+            sum = fma(v, k, sum);
+        }
+    }
+    return sum;
+}
+
+void main() {
+    uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+    if (idx >= p.ne) {
+        return;
+    }
+
+    FLOAT_TYPE result =
+#ifdef WHCN
+        conv_2d_dw_whcn(idx);
+#else
+        conv_2d_dw_cwhn(idx);
+#endif
+    dst_data[idx] = D_TYPE(result);
+}
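Per the review note in the commit message, the shader has no explicit `src_x < 0` / `src_y < 0` checks: the coordinates are unsigned, so a nominally negative value wraps around to a huge integer and the single `>=` bound test rejects it. A minimal standalone check of the trick (values assumed):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t src_w = 17;                // image width from the first eval test
        uint32_t src_x = 0u * 1u + 0u * 1u - 1u;  // dst_x=0, knl_x=0, pad_x=1 -> "-1"
        assert(src_x == 4294967295u);             // wrapped, not negative
        assert(src_x >= src_w);                   // so the out-of-bounds tap is skipped
        return 0;
    }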

src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
@@ -544,6 +544,9 @@ void process_shaders() {
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+    string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }
tests/test-backend-ops.cpp

Lines changed: 50 additions & 0 deletions
@@ -2762,6 +2762,48 @@ struct test_im2col : public test_case {
     }
 };
 
+// GGML_OP_CONV_2D_DW
+struct test_conv_2d_dw : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    const int stride;
+    const int padding;
+    const int dilation;
+    const bool cwhn;
+
+    std::string vars() override {
+        return VARS_TO_STR6(ne_input, ne_kernel, stride, padding, dilation, cwhn);
+    }
+
+    test_conv_2d_dw(std::array<int64_t, 4> ne_input = {64, 64, 16, 1},
+            std::array<int64_t, 4> ne_kernel = {3, 3, 1, 16},
+            int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)
+        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
+            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * out = ggml_conv_2d_dw_direct(
+            ctx, kernel, input,
+            stride, stride, padding, padding, dilation, dilation);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
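A note on the `cwhn` branch above: the first `ggml_permute` plus `ggml_cont` materializes a channels-innermost copy, and the second `ggml_permute` restores the original logical shape as a view. The result has the same `ne` as the plain WHCN tensor but different strides, which is exactly what `ggml_is_contiguous_channels` detects when the backend picks the CWHN pipeline. A worked illustration with assumed sizes (not from the diff):

    // input ne = {W=4, H=2, C=3, N=1}, f32
    // after cont(permute(input, 1, 2, 0, 3)): ne = {3, 4, 2, 1}, channels contiguous
    // after permute(input, 2, 0, 1, 3):       ne = {4, 2, 3, 1} again, but
    //   nb = {12, 48, 4, 96} bytes -- the channel stride (4) is now the smallest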
@@ -3972,6 +4014,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
     // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
 
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
+
     test_cases.emplace_back(new test_conv_transpose_1d());
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
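Assuming the harness's usual filters apply to the new op (the `-o` flag and op naming follow test-backend-ops convention; they are not added by this diff), the new cases can be exercised in isolation:

    ./bin/test-backend-ops test -o CONV_2D_DW
    ./bin/test-backend-ops perf -o CONV_2D_DW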
@@ -4546,6 +4593,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
+
     return test_cases;
 }
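For scale: each perf case computes a same-size 512x512x256 output (3x3 kernel, stride 1, pad 1), i.e. 512 * 512 * 256 = 67,108,864 output elements at 9 fused multiply-adds each, roughly 1.2 GFLOP per run, which should be enough work for the layout difference between the WHCN and CWHN kernels to show up clearly in the timings.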
