Skip to content

Commit f6ac084

Browse files
committed
Feat: Added vulkan circular tiling support
1 parent c5023da commit f6ac084

File tree

6 files changed

+310
-30
lines changed

6 files changed

+310
-30
lines changed

ggml/include/ggml.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,6 +1943,18 @@ extern "C" {
19431943
int d0, // dilation dimension 0
19441944
int d1); // dilation dimension 1
19451945

1946+
1947+
GGML_API struct ggml_tensor * ggml_conv_2d_circular(
1948+
struct ggml_context * ctx,
1949+
struct ggml_tensor * a, // convolution kernel
1950+
struct ggml_tensor * b, // data
1951+
int s0, // stride dimension 0
1952+
int s1, // stride dimension 1
1953+
int p0, // padding dimension 0
1954+
int p1, // padding dimension 1
1955+
int d0, // dilation dimension 0
1956+
int d1); // dilation dimension 1
1957+
19461958
GGML_API struct ggml_tensor * ggml_im2col_3d(
19471959
struct ggml_context * ctx,
19481960
struct ggml_tensor * a,
@@ -2016,6 +2028,19 @@ extern "C" {
20162028
int d0, // dilation dimension 0
20172029
int d1); // dilation dimension 1
20182030

2031+
2032+
// depthwise (via im2col and mul_mat)
2033+
GGML_API struct ggml_tensor * ggml_conv_2d_dw_circular(
2034+
struct ggml_context * ctx,
2035+
struct ggml_tensor * a, // convolution kernel
2036+
struct ggml_tensor * b, // data
2037+
int s0, // stride dimension 0
2038+
int s1, // stride dimension 1
2039+
int p0, // padding dimension 0
2040+
int p1, // padding dimension 1
2041+
int d0, // dilation dimension 0
2042+
int d1); // dilation dimension 1
2043+
20192044
// Depthwise 2D convolution
20202045
// may be faster than ggml_conv_2d_dw, but not available in all backends
20212046
// a: KW KH 1 C convolution kernel
@@ -2032,12 +2057,35 @@ extern "C" {
20322057
int dilation0,
20332058
int dilation1);
20342059

2060+
// Depthwise 2D convolution (on a torus)
2061+
// may be faster than ggml_conv_2d_dw, but not available in all backends
2062+
// a: KW KH 1 C convolution kernel
2063+
// b: W H C N input data
2064+
// res: W_out H_out C N
2065+
GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct_circular(
2066+
struct ggml_context * ctx,
2067+
struct ggml_tensor * a,
2068+
struct ggml_tensor * b,
2069+
int stride0,
2070+
int stride1,
2071+
int pad0,
2072+
int pad1,
2073+
int dilation0,
2074+
int dilation1);
2075+
20352076
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
20362077
struct ggml_context * ctx,
20372078
struct ggml_tensor * a,
20382079
struct ggml_tensor * b,
20392080
int stride);
20402081

2082+
// circular (on a torus)
2083+
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0_circular(
2084+
struct ggml_context * ctx,
2085+
struct ggml_tensor * a,
2086+
struct ggml_tensor * b,
2087+
int stride);
2088+
20412089
GGML_API struct ggml_tensor * ggml_conv_2d_direct(
20422090
struct ggml_context * ctx,
20432091
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
@@ -2048,6 +2096,17 @@ extern "C" {
20482096
int p1, // padding dimension 1
20492097
int d0, // dilation dimension 0
20502098
int d1); // dilation dimension 1
2099+
2100+
GGML_API struct ggml_tensor * ggml_conv_2d_direct_circular(
2101+
struct ggml_context * ctx,
2102+
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
2103+
struct ggml_tensor * b, // input data [W, H, C, N]
2104+
int s0, // stride dimension 0
2105+
int s1, // stride dimension 1
2106+
int p0, // padding dimension 0
2107+
int p1, // padding dimension 1
2108+
int d0, // dilation dimension 0
2109+
int d1); // dilation dimension 1
20512110

20522111
GGML_API struct ggml_tensor * ggml_conv_3d_direct(
20532112
struct ggml_context * ctx,
@@ -2156,6 +2215,15 @@ extern "C" {
21562215
int p2,
21572216
int p3);
21582217

2218+
// pad each dimension with values on the other side of the torus (looping around)
2219+
GGML_API struct ggml_tensor * ggml_pad_circular(
2220+
struct ggml_context * ctx,
2221+
struct ggml_tensor * a,
2222+
int p0,
2223+
int p1,
2224+
int p2,
2225+
int p3);
2226+
21592227
GGML_API struct ggml_tensor * ggml_pad_ext(
21602228
struct ggml_context * ctx,
21612229
struct ggml_tensor * a,
@@ -2169,6 +2237,20 @@ extern "C" {
21692237
int rp3
21702238
);
21712239

2240+
// circular padding
2241+
GGML_API struct ggml_tensor * ggml_pad_ext_circular(
2242+
struct ggml_context * ctx,
2243+
struct ggml_tensor * a,
2244+
int lp0,
2245+
int rp0,
2246+
int lp1,
2247+
int rp1,
2248+
int lp2,
2249+
int rp2,
2250+
int lp3,
2251+
int rp3
2252+
);
2253+
21722254
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
21732255
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
21742256
struct ggml_context * ctx,

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,7 @@ struct vk_op_pad_push_constants {
940940
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
941941
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
942942
uint32_t misalign_offsets;
943+
uint32_t circular;
943944

944945
uint32_t lp0; uint32_t rp0;
945946
uint32_t lp1; uint32_t rp1;
@@ -982,6 +983,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
982983
p.rp2 = dst->op_params[5];
983984
p.lp3 = dst->op_params[6];
984985
p.rp3 = dst->op_params[7];
986+
p.circular = dst->op_params[8];
985987

986988
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
987989
}
@@ -1249,6 +1251,8 @@ struct vk_op_conv2d_push_constants {
12491251
uint32_t KWKHmp; uint32_t KWKHL;
12501252
uint32_t OWmp; uint32_t OWL;
12511253
uint32_t OWOHmp; uint32_t OWOHL;
1254+
1255+
uint32_t circular;
12521256
};
12531257

12541258
template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
@@ -1297,6 +1301,8 @@ struct vk_op_conv_transpose_2d_push_constants {
12971301
uint32_t OWOHmp; uint32_t OWOHL;
12981302
uint32_t s0mp; uint32_t s0L;
12991303
uint32_t s1mp; uint32_t s1L;
1304+
1305+
uint32_t circular;
13001306
};
13011307

13021308
template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
@@ -1325,6 +1331,7 @@ struct vk_op_conv2d_dw_push_constants {
13251331
int32_t pad_y;
13261332
int32_t dilation_x;
13271333
int32_t dilation_y;
1334+
uint32_t circular;
13281335
};
13291336

13301337
struct vk_op_upscale_push_constants {
@@ -10420,6 +10427,8 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
1042010427
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
1042110428
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
1042210429

10430+
p.circular = static_cast<uint32_t>(dst->op_params[6]);
10431+
1042310432
GGML_ASSERT(ne03 == ne2);
1042410433
GGML_ASSERT(ne02 == ne12);
1042510434

@@ -10469,6 +10478,8 @@ static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context
1046910478
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
1047010479
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
1047110480

10481+
p.circular = static_cast<uint32_t>(dst->op_params[1]);
10482+
1047210483
GGML_ASSERT(ne02 == ne2);
1047310484
GGML_ASSERT(ne03 == ne12);
1047410485

@@ -10492,6 +10503,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
1049210503
p.pad_y = dst->op_params[3];
1049310504
p.dilation_x = dst->op_params[4];
1049410505
p.dilation_y = dst->op_params[5];
10506+
p.circular = dst->op_params[6];
1049510507

1049610508
GGML_ASSERT(src0->ne[3] == p.channels);
1049710509
GGML_ASSERT(src1->ne[3] == p.batches);

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
1919
int pad_y;
2020
int dilation_x;
2121
int dilation_y;
22+
uint circular;
2223
} p;
2324

2425
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
@@ -27,6 +28,10 @@ layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
2728

2829
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
2930

31+
uint32_t wrap_coord(int coord, uint32_t size) {
32+
return uint32_t((uint(coord + int(size))) % size);
33+
}
34+
3035
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
3136
uint i0 = idx / p.dst_w;
3237
uint dst_x = idx - i0 * p.dst_w;
@@ -39,19 +44,35 @@ FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
3944
uint knl_i = c * p.knl_h * p.knl_w;
4045

4146
FLOAT_TYPE sum = 0.0;
42-
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
43-
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
44-
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
45-
continue;
47+
48+
if (p.circular != 0u) {
49+
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
50+
int raw_y = int(dst_y) * p.stride_y + int(knl_y) * p.dilation_y - p.pad_y;
51+
uint src_y = wrap_coord(raw_y, p.src_h);
52+
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
53+
int raw_x = int(dst_x) * p.stride_x + int(knl_x) * p.dilation_x - p.pad_x;
54+
uint src_x = wrap_coord(raw_x, p.src_w);
55+
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
56+
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
57+
sum = fma(v, k, sum);
58+
}
4659
}
47-
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
48-
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
49-
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
60+
}
61+
else {
62+
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
63+
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
64+
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
5065
continue;
5166
}
52-
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
53-
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
54-
sum = fma(v, k, sum);
67+
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
68+
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
69+
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
70+
continue;
71+
}
72+
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
73+
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
74+
sum = fma(v, k, sum);
75+
}
5576
}
5677
}
5778
return sum;
@@ -70,19 +91,34 @@ FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
7091
uint knl_row = p.knl_w * p.channels;
7192

7293
FLOAT_TYPE sum = 0.0;
73-
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
74-
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
75-
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
76-
continue;
94+
if (p.circular != 0u) {
95+
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
96+
int raw_y = int(dst_y) * p.stride_y + int(knl_y) * p.dilation_y - p.pad_y;
97+
uint src_y = wrap_coord(raw_y, p.src_h);
98+
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
99+
int raw_x = int(dst_x) * p.stride_x + int(knl_x) * p.dilation_x - p.pad_x;
100+
uint src_x = wrap_coord(raw_x, p.src_w);
101+
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
102+
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
103+
sum = fma(v, k, sum);
104+
}
77105
}
78-
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
79-
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
80-
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
106+
}
107+
else {
108+
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
109+
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
110+
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
81111
continue;
82112
}
83-
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
84-
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
85-
sum = fma(v, k, sum);
113+
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
114+
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
115+
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
116+
continue;
117+
}
118+
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
119+
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
120+
sum = fma(v, k, sum);
121+
}
86122
}
87123
}
88124
return sum;

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ layout(push_constant) uniform parameter {
7070
uint32_t s0mp; uint32_t s0L;
7171
uint32_t s1mp; uint32_t s1L;
7272
#endif
73+
74+
uint32_t circular;
7375
}
7476

7577
p;
@@ -174,6 +176,10 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T
174176
}
175177
#endif
176178

179+
uint32_t wrap_coord(int coord, uint32_t size) {
180+
return uint32_t((uint(coord + int(size))) % size);
181+
}
182+
177183
void main() {
178184
#ifdef COOPMAT2
179185
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
@@ -274,7 +280,8 @@ void main() {
274280
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
275281
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
276282
#endif
277-
283+
uint32_t H_pos;
284+
uint32_t W_pos;
278285
#ifdef TRANSPOSE
279286
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
280287
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
@@ -284,13 +291,15 @@ void main() {
284291
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
285292
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
286293
#endif
294+
H_pos = (p.circular != 0) ? wrap_coord(int(H_idx), p.H) : H_idx;
295+
W_pos = (p.circular != 0) ? wrap_coord(int(W_idx), p.W) : W_idx;
287296
uint32_t src_idx =
288-
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
297+
min(max(W_pos + H_pos * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
289298
float val = src_data[src_idx];
290299
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
291-
|| H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
300+
|| H_pos >= p.H || W_pos >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
292301
#ifdef TRANSPOSE
293-
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
302+
|| (H_idx_x_s1 - H_pos * p.s1 != 0) || (W_idx_x_s0 - W_pos * p.s0 != 0)
294303
#endif
295304
) {
296305
val = 0.0;

0 commit comments

Comments
 (0)