@@ -8,6 +8,7 @@ struct conv_params {
88 int padding_x, padding_y;
99 int dilation_x, dilation_y;
1010 int channels, batches;
11+ int circular;
1112};
1213
1314struct kernel_bounds {
@@ -17,21 +18,34 @@ struct kernel_bounds {
1718
1819__device__ __forceinline__ kernel_bounds calculate_kernel_bounds (int out_x, int out_y, const conv_params & params) {
1920 kernel_bounds bounds;
20- bounds.y_min = max (0 , (params.padding_y - out_y * params.stride_y + params.dilation_y - 1 ) / params.dilation_y );
21- bounds.y_max =
22- min (params.kernel_h ,
23- (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1 ) / params.dilation_y );
24- bounds.x_min = max (0 , (params.padding_x - out_x * params.stride_x + params.dilation_x - 1 ) / params.dilation_x );
25- bounds.x_max =
26- min (params.kernel_w ,
27- (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1 ) / params.dilation_x );
21+ if (params.circular ) {
22+ bounds.y_min = 0 ;
23+ bounds.y_max = params.kernel_h ;
24+ bounds.x_min = 0 ;
25+ bounds.x_max = params.kernel_w ;
26+ }
27+ else {
28+ bounds.y_min = max (0 , (params.padding_y - out_y * params.stride_y + params.dilation_y - 1 ) / params.dilation_y );
29+ bounds.y_max =
30+ min (params.kernel_h ,
31+ (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1 ) / params.dilation_y );
32+ bounds.x_min = max (0 , (params.padding_x - out_x * params.stride_x + params.dilation_x - 1 ) / params.dilation_x );
33+ bounds.x_max =
34+ min (params.kernel_w ,
35+ (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1 ) / params.dilation_x );
36+
37+ }
2838 return bounds;
2939}
3040
3141__device__ __forceinline__ int calculate_input_coord (int out_coord, int kern_coord, int stride, int dilation, int padding) {
3242 return out_coord * stride + kern_coord * dilation - padding;
3343}
3444
45+ __device__ __forceinline__ int wrap_coord (int coord, int size) {
46+ return (coord % size + size) % size;
47+ }
48+
3549struct whcn_layout {
3650 __device__ static int input_index (int n, int c, int y, int x, const conv_params & params) {
3751 return n * (params.channels * params.in_w * params.in_h ) + c * params.in_w * params.in_h + y * params.in_w + x;
@@ -83,7 +97,8 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
8397 const int in_w, const int in_h, const int out_w, const int out_h,
8498 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
8599 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
86- const int channels, const int batches) {
100+ const int channels, const int batches,
101+ const int circular) {
87102 const int global_idx = blockIdx .x * blockDim .x + threadIdx .x ;
88103 const int total_elements = batches * channels * out_h * out_w;
89104
@@ -92,26 +107,43 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
92107 }
93108
94109 conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
95- stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
110+ stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches, circular };
96111
97112 int batch_idx, channel_idx, out_y_idx, out_x_idx;
98113 Layout::unpack_indices (global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
99114
100115 T accumulator = 0 ;
101116 kernel_bounds bounds = calculate_kernel_bounds (out_x_idx, out_y_idx, params);
102117
103- for (int kern_y = bounds.y_min ; kern_y < bounds.y_max ; ++kern_y) {
104- int in_y_idx = calculate_input_coord (out_y_idx, kern_y, params.stride_y , params.dilation_y , params.padding_y );
118+ if (params.circular == 0 ) {
119+ for (int kern_y = bounds.y_min ; kern_y < bounds.y_max ; ++kern_y) {
120+ int src_y_idx = calculate_input_coord (out_y_idx, kern_y, params.stride_y , params.dilation_y , params.padding_y );
121+
122+ for (int kern_x = bounds.x_min ; kern_x < bounds.x_max ; ++kern_x) {
123+ int src_x_idx = calculate_input_coord (out_x_idx, kern_x, params.stride_x , params.dilation_x , params.padding_x );
124+
125+ const T input_val = input[Layout::input_index (batch_idx, channel_idx, src_y_idx, src_x_idx, params)];
126+ const T kernel_val = kernel[Layout::kernel_index (channel_idx, kern_y, kern_x, params)];
127+
128+ accumulator += input_val * kernel_val;
129+ }
130+ }
131+ }
132+ else {
133+ for (int kern_y = bounds.y_min ; kern_y < bounds.y_max ; ++kern_y) {
134+ int in_y_idx = wrap_coord (calculate_input_coord (out_y_idx, kern_y, params.stride_y , params.dilation_y , params.padding_y ), params.in_h );
105135
106- for (int kern_x = bounds.x_min ; kern_x < bounds.x_max ; ++kern_x) {
107- int in_x_idx = calculate_input_coord (out_x_idx, kern_x, params.stride_x , params.dilation_x , params.padding_x );
136+ for (int kern_x = bounds.x_min ; kern_x < bounds.x_max ; ++kern_x) {
137+ int in_x_idx = wrap_coord (calculate_input_coord (out_x_idx, kern_x, params.stride_x , params.dilation_x , params.padding_x ), params.in_w );
108138
109- const T input_val = input[Layout::input_index (batch_idx, channel_idx, in_y_idx, in_x_idx , params)];
110- const T kernel_val = kernel[Layout::kernel_index (channel_idx, kern_y, kern_x, params)];
139+ const T input_val = input[Layout::input_index (batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
140+ const T kernel_val = kernel[Layout::kernel_index (channel_idx, kern_y, kern_x, params)];
111141
112- accumulator += input_val * kernel_val;
142+ accumulator += input_val * kernel_val;
143+ }
113144 }
114145 }
146+
115147
116148 output[Layout::output_index (batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
117149}
@@ -132,6 +164,7 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
132164 const int padding_y = p[3 ];
133165 const int dilation_x = p[4 ];
134166 const int dilation_y = p[5 ];
167+ const int circular = p[6 ];
135168
136169 const int in_w = input->ne [0 ];
137170 const int in_h = input->ne [1 ];
@@ -150,11 +183,11 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
150183 if (ggml_is_contiguous (input)) {
151184 conv2d_dw_kernel<float , whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0 , st>>> (
152185 x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
153- dilation_x, dilation_y, channels, batches);
186+ dilation_x, dilation_y, channels, batches, circular );
154187 } else if (ggml_is_contiguous_channels (input)) {
155188 conv2d_dw_kernel<float , cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0 , st>>> (
156189 x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
157- dilation_x, dilation_y, channels, batches);
190+ dilation_x, dilation_y, channels, batches, circular );
158191 } else {
159192 GGML_ABORT (" Unsupported memory layout for conv_2d_dw" );
160193 }
0 commit comments