Skip to content

Commit 1b62b49

Browse files
committed
Feat: Added circular-padding (wrap-around) support to the CUDA conv2d, depthwise-conv2d, and transposed-conv2d kernels
1 parent d7f5958 commit 1b62b49

File tree

4 files changed

+204
-75
lines changed

4 files changed

+204
-75
lines changed

ggml/src/ggml-cuda/conv2d-dw.cu

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ struct conv_params {
88
int padding_x, padding_y;
99
int dilation_x, dilation_y;
1010
int channels, batches;
11+
int circular;
1112
};
1213

1314
struct kernel_bounds {
@@ -17,21 +18,34 @@ struct kernel_bounds {
1718

1819
// Compute the valid kernel-tap range [x_min, x_max) x [y_min, y_max) for one
// output pixel. With circular (wrap-around) padding every tap is valid, so the
// bounds span the full kernel window; with zero padding the window is clipped
// so the raw input coordinates stay inside [0, in_w) x [0, in_h).
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    kernel_bounds bounds;

    if (params.circular) {
        bounds.x_min = 0;
        bounds.x_max = params.kernel_w;
        bounds.y_min = 0;
        bounds.y_max = params.kernel_h;
        return bounds;
    }

    // base_* is the (possibly negative) offset of the first input coordinate
    // touched by this output pixel; the ceil-divisions clip the tap range.
    const int base_y = params.padding_y - out_y * params.stride_y;
    bounds.y_min     = max(0, (base_y + params.dilation_y - 1) / params.dilation_y);
    bounds.y_max     = min(params.kernel_h, (params.in_h + base_y + params.dilation_y - 1) / params.dilation_y);

    const int base_x = params.padding_x - out_x * params.stride_x;
    bounds.x_min     = max(0, (base_x + params.dilation_x - 1) / params.dilation_x);
    bounds.x_max     = min(params.kernel_w, (params.in_w + base_x + params.dilation_x - 1) / params.dilation_x);

    return bounds;
}
3040

3141
// Input coordinate read by kernel tap `kern_coord` when producing output
// position `out_coord`. May be negative or >= the input size; callers are
// responsible for clipping (zero padding) or wrapping (circular padding).
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    const int strided = out_coord * stride;
    const int dilated = kern_coord * dilation;
    return strided + dilated - padding;
}
3444

45+
// Map `coord` into [0, size) for circular (wrap-around) addressing.
// The explicit branch handles negative coordinates; assumes size > 0.
__device__ __forceinline__ int wrap_coord(int coord, int size) {
    const int m = coord % size;
    return m < 0 ? m + size : m;
}
48+
3549
struct whcn_layout {
3650
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
3751
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
// Depthwise 2D convolution: one thread per output element, launched as a 1D
// grid of 1D blocks. `Layout` (whcn or cwhn) maps logical (n, c, y, x) indices
// to linear offsets. When `circular` is non-zero, out-of-range input
// coordinates wrap around (periodic padding) instead of being skipped.
//
// NOTE(review): the committed circular branch computed wrapped coordinates
// into in_y_idx/in_x_idx but then indexed input with src_y_idx/src_x_idx,
// which are declared only in the non-circular branch. Fixed here, and the two
// near-identical loop nests merged; the per-tap `circular` branch is uniform
// across the grid, so it introduces no warp divergence.
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
                                 const int in_w, const int in_h, const int out_w, const int out_h,
                                 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
                                 const int channels, const int batches,
                                 const int circular) {
    const int global_idx     = blockIdx.x * blockDim.x + threadIdx.x;
    const int total_elements = batches * channels * out_h * out_w;

    if (global_idx >= total_elements) {
        return;
    }

    conv_params params = { in_w,     in_h,      out_w,     out_h,      kernel_w,   kernel_h, stride_x,
                           stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches, circular };

    int batch_idx, channel_idx, out_y_idx, out_x_idx;
    Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);

    T accumulator = 0;
    // Circular padding yields full-window bounds; zero padding yields bounds
    // clipped so the unwrapped coordinates below are always in-range.
    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);

    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
        if (params.circular) {
            in_y_idx = wrap_coord(in_y_idx, params.in_h);
        }

        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
            int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
            if (params.circular) {
                in_x_idx = wrap_coord(in_x_idx, params.in_w);
            }

            const T input_val  = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];

            accumulator += input_val * kernel_val;
        }
    }

    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}
@@ -132,6 +164,7 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
132164
const int padding_y = p[3];
133165
const int dilation_x = p[4];
134166
const int dilation_y = p[5];
167+
const int circular = p[6];
135168

136169
const int in_w = input->ne[0];
137170
const int in_h = input->ne[1];
@@ -150,11 +183,11 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
150183
if (ggml_is_contiguous(input)) {
151184
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
152185
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
153-
dilation_x, dilation_y, channels, batches);
186+
dilation_x, dilation_y, channels, batches, circular);
154187
} else if (ggml_is_contiguous_channels(input)) {
155188
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
156189
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
157-
dilation_x, dilation_y, channels, batches);
190+
dilation_x, dilation_y, channels, batches, circular);
158191
} else {
159192
GGML_ABORT("Unsupported memory layout for conv_2d_dw");
160193
}

ggml/src/ggml-cuda/conv2d-transpose.cu

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@
33
#include "conv2d-transpose.cuh"
44
#include "ggml.h"
55

6+
7+
// Map `coord` into [0, size) for circular (wrap-around) addressing.
// The explicit branch handles negative coordinates; assumes size > 0.
__device__ __forceinline__ int wrap_coord(int coord, int size) {
    const int m = coord % size;
    return m < 0 ? m + size : m;
}
10+
611
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
712
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
813
const int out_h, const int kernel_w, const int kernel_h, const int stride,
9-
const int c_in, const int c_out, const int batches) {
14+
const int c_in, const int c_out, const int batches,
15+
const int circular) {
1016
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
1117

1218
const int total_elements = out_w * out_h * c_out * batches;
@@ -22,28 +28,55 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const
2228

2329
float accumulator = 0;
2430
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
25-
26-
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
27-
for (int kh = 0; kh < kernel_h; ++kh) {
28-
int in_y = out_y_idx - kh;
29-
if (in_y < 0 || in_y % stride) continue;
30-
in_y /= stride;
31-
if (in_y >= in_h) continue;
32-
33-
for (int kw = 0; kw < kernel_w; ++kw) {
34-
int in_x = out_x_idx - kw;
35-
if (in_x < 0 || in_x % stride) continue;
36-
in_x /= stride;
37-
if (in_x >= in_w) continue;
38-
39-
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
40-
const int kernel_idx =
41-
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
42-
43-
float input_val = input[input_idx];
44-
half kern_val = kernel[kernel_idx];
45-
46-
accumulator += input_val * (float) kern_val;
31+
if (circular == 0) {
32+
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
33+
for (int kh = 0; kh < kernel_h; ++kh) {
34+
int in_y = out_y_idx - kh;
35+
if (in_y < 0 || in_y % stride) continue;
36+
in_y /= stride;
37+
if (in_y >= in_h) continue;
38+
39+
for (int kw = 0; kw < kernel_w; ++kw) {
40+
int in_x = out_x_idx - kw;
41+
if (in_x < 0 || in_x % stride) continue;
42+
in_x /= stride;
43+
if (in_x >= in_w) continue;
44+
45+
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
46+
const int kernel_idx =
47+
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
48+
49+
float input_val = input[input_idx];
50+
half kern_val = kernel[kernel_idx];
51+
52+
accumulator += input_val * (float) kern_val;
53+
}
54+
}
55+
}
56+
}
57+
else {
58+
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
59+
for (int kh = 0; kh < kernel_h; ++kh) {
60+
int in_y = out_y_idx - kh;
61+
if (in_y % stride) continue;
62+
in_y /= stride;
63+
in_y = wrap_coord(in_y, in_h);
64+
65+
for (int kw = 0; kw < kernel_w; ++kw) {
66+
int in_x = out_x_idx - kw;
67+
if (in_x % stride) continue;
68+
in_x /= stride;
69+
in_x = wrap_coord(in_x, in_w);
70+
71+
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
72+
const int kernel_idx =
73+
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
74+
75+
float input_val = input[input_idx];
76+
half kern_val = kernel[kernel_idx];
77+
78+
accumulator += input_val * (float) kern_val;
79+
}
4780
}
4881
}
4982
}
@@ -72,6 +105,7 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
72105
const int kernel_h = kernel->ne[1];
73106
const int stride = dst->op_params[0];
74107
const int batches = input->ne[3];
108+
const int circular = dst->op_params[1];
75109

76110
GGML_ASSERT(channels_in == kernel->ne[3]);
77111
GGML_ASSERT(stride > 0);
@@ -87,5 +121,5 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
87121

88122
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
89123
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
90-
channels_in, channels_out, batches);
124+
channels_in, channels_out, batches, circular);
91125
}

ggml/src/ggml-cuda/conv2d.cu

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct conv_params {
1111
const int64_t IC, OC;
1212
const int64_t B;
1313
const int64_t TOTAL;
14+
const int64_t CIRCULAR;
1415
};
1516

1617
struct kernel_bounds {
@@ -26,12 +27,24 @@ __device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
2627
return (a < b) ? a : b;
2728
}
2829

30+
// Map `coord` into [0, size) for circular (wrap-around) addressing; the double
// modulo handles negative coordinates. Widened to int64_t: every call site in
// this file passes int64_t values (calculate_input_coord returns int64_t and
// P.IW/P.IH are int64_t), and the previous int parameters silently truncated
// them for very large tensors.
__device__ __forceinline__ int64_t wrap_coord(int64_t coord, int64_t size) {
    return (coord % size + size) % size;
}
33+
2934
// Compute the valid kernel-tap range [x_min, x_max) x [y_min, y_max) for one
// output pixel. With circular (wrap-around) padding every tap is valid, so the
// bounds span the full kernel window; with zero padding the window is clipped
// so input coordinates stay inside [0, P.IW) x [0, P.IH).
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
    kernel_bounds bounds;

    if (P.CIRCULAR) {
        bounds.x_min = 0;
        bounds.x_max = P.KW;
        bounds.y_min = 0;
        bounds.y_max = P.KH;
        return bounds;
    }

    // base_* is the (possibly negative) offset of the first input coordinate
    // touched by this output pixel; the ceil-divisions clip the tap range.
    const int64_t base_y = P.PD_Y - out_y * P.ST_Y;
    bounds.y_min         = max64(0, (base_y + P.DL_Y - 1) / P.DL_Y);
    bounds.y_max         = min64(P.KH, (P.IH + base_y + P.DL_Y - 1) / P.DL_Y);

    const int64_t base_x = P.PD_X - out_x * P.ST_X;
    bounds.x_min         = max64(0, (base_x + P.DL_X - 1) / P.DL_X);
    bounds.x_max         = min64(P.KW, (P.IW + base_x + P.DL_X - 1) / P.DL_X);

    return bounds;
}
3750

@@ -84,19 +97,37 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
8497
Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
8598

8699
float acc = 0.0f;
100+
if (P.CIRCULAR == 0) {
101+
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
102+
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
103+
104+
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
105+
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
87106

88-
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
89-
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
107+
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
108+
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
109+
110+
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
111+
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
112+
acc += (input_val * ggml_cuda_cast<float>(kernel_val));
113+
}
114+
}
115+
}
116+
}
117+
else {
118+
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
119+
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
90120

91-
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
92-
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
121+
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
122+
const int64_t in_y = wrap_coord(calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y), P.IH);
93123

94-
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
95-
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
124+
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
125+
const int64_t in_x = wrap_coord(calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X), P.IW);
96126

97-
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
98-
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
99-
acc += (input_val * ggml_cuda_cast<float>(kernel_val));
127+
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
128+
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
129+
acc += (input_val * ggml_cuda_cast<float>(kernel_val));
130+
}
100131
}
101132
}
102133
}
@@ -141,6 +172,7 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
141172
const int PD_Y = p[3]; // padding_y
142173
const int DL_X = p[4]; // dilation_x
143174
const int DL_Y = p[5]; // dilation_y
175+
const int CIRCULAR = p[6];
144176

145177
// No cwhn
146178
GGML_ASSERT(p[6] == false);
@@ -156,7 +188,7 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
156188
const int B = input->ne[3]; // n_batches
157189

158190
const int64_t total = B * OC * OH * OW;
159-
conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
191+
conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total, CIRCULAR };
160192

161193
if (kernel->type == GGML_TYPE_F16) {
162194
conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);

0 commit comments

Comments
 (0)