@@ -6064,6 +6064,178 @@ void ggml_compute_forward_conv_transpose_2d(
 }
 }
 
+// ggml_compute_forward_conv_2d_dw
+
+struct ggml_conv_2d_dw_params {
+    int64_t channels;
+    int64_t batch;
+    int64_t src_w;
+    int64_t src_h;
+    int64_t dst_w;
+    int64_t dst_h;
+    int64_t knl_w;
+    int64_t knl_h;
+    int stride_x;
+    int stride_y;
+    int pad_x;
+    int pad_y;
+    int dilation_x;
+    int dilation_y;
+};
+
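+// Depthwise 2D convolution for inputs whose channel dimension is contiguous in
+// memory (CWHN layout). The inner accumulation runs across channels, using SIMD
+// where available; output rows (dst_h * batch) are split across threads.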
+static void ggml_compute_forward_conv_2d_dw_cwhn(
+        const ggml_compute_params * params,
+        const ggml_tensor * src,
+        const ggml_tensor * kernel,
+        ggml_tensor * dst,
+        const ggml_conv_2d_dw_params & p) {
+
+    const int64_t c = p.channels;
+    const float * knl_data = (const float *)kernel->data;
+
+    const int64_t rows_total = p.dst_h * p.batch;
+    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+    const int64_t row_start = params->ith * rows_per_thread;
+    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
+
+#ifdef GGML_SIMD
+    const int64_t pkg_size = GGML_F32_EPR;
+    const int64_t pkg_count = c / pkg_size;
+    const int64_t c_pkg_end = pkg_count * pkg_size;
+#else
+    const int64_t c_pkg_end = 0;
+#endif
+
+    for (int64_t row = row_start; row < row_end; ++row) {
+        const int64_t dst_y = row % p.dst_h;
+        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
+        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
+            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
+            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
+
+#ifdef GGML_SIMD
+            // Vectorized loop
+            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
+                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
+                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
+                        sum = GGML_F32_VEC_FMA(sum, k, s);
+                    }
+                }
+                GGML_F32_VEC_STORE(dst_data + c_i, sum);
+            }
+#endif
+            // Scalar loop
+            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
+                float sum = 0.0f;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
+                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
+                    }
+                }
+                dst_data[c_i] = sum;
+            }
+        }
+    }
+}
+
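+// Depthwise 2D convolution for the default contiguous layout (WHCN). Each of the
+// channels * batch planes is independent, so planes are split across threads and
+// scanned in row-major order.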
+static void ggml_compute_forward_conv_2d_dw_whcn(
+        const ggml_compute_params * params,
+        const ggml_tensor * src,
+        const ggml_tensor * kernel,
+        ggml_tensor * dst,
+        const ggml_conv_2d_dw_params & p) {
+
+    const int64_t n = p.channels * p.batch;
+    const int64_t per_thread = (n + params->nth - 1) / params->nth;
+    const int64_t start = params->ith * per_thread;
+    const int64_t end = MIN(start + per_thread, n);
+
+    for (int64_t i = start; i < end; ++i) {
+        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
+        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
+        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
+
+        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
+            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+
+                float sum = 0.0f;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        sum += knl_data[knl_y * p.knl_w + knl_x]
+                             * src_data[src_y * p.src_w + src_x];
+                    }
+                }
+                dst_data[dst_y * p.dst_w + dst_x] = sum;
+            }
+        }
+    }
+}
+
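+// Forward entry point: dst->src[0] is the depthwise kernel, dst->src[1] is the
+// input; stride, padding and dilation come from dst->op_params. Dispatches to
+// the WHCN or CWHN variant based on the input's memory layout.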
+void ggml_compute_forward_conv_2d_dw(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * src = dst->src[1];
+    ggml_conv_2d_dw_params p;
+    p.channels = src->ne[2];
+    p.batch = src->ne[3];
+    p.src_w = src->ne[0];
+    p.src_h = src->ne[1];
+    p.dst_w = dst->ne[0];
+    p.dst_h = dst->ne[1];
+    p.knl_w = kernel->ne[0];
+    p.knl_h = kernel->ne[1];
+    p.stride_x = dst->op_params[0];
+    p.stride_y = dst->op_params[1];
+    p.pad_x = dst->op_params[2];
+    p.pad_y = dst->op_params[3];
+    p.dilation_x = dst->op_params[4];
+    p.dilation_y = dst->op_params[5];
+
+    GGML_ASSERT(kernel->ne[3] == p.channels);
+    GGML_ASSERT(dst->ne[3] == p.batch);
+
+    if (ggml_is_contiguous(src)) {
+        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
+    } else if (ggml_is_contiguous_channels(src)) {
+        // kernel should also have channels most contiguous in memory
+        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
+        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
+    } else {
+        GGML_ABORT("non-contiguous memory layout not supported");
+    }
+}
+
 // ggml_compute_forward_pool_1d_sk_p0
 
 static void ggml_compute_forward_pool_1d_sk_p0(
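For reference, the destination sizes these loops assume follow the usual convolution relation dst_w = (src_w + 2*pad_x - dilation_x*(knl_w - 1) - 1) / stride_x + 1 (and likewise for the height). The sketch below is not part of the patch and uses illustrative names (ref_conv_2d_dw); it applies the same per-channel accumulation as the WHCN path to a single 4x4 plane with an identity kernel, so the output reproduces the input.

// Standalone reference for the WHCN depthwise-conv inner loops above (illustrative only).
#include <stdio.h>

static void ref_conv_2d_dw(const float * src, const float * knl, float * dst,
                           int src_w, int src_h, int knl_w, int knl_h,
                           int dst_w, int dst_h,
                           int stride, int pad, int dilation) {
    for (int dst_y = 0; dst_y < dst_h; ++dst_y) {
        for (int dst_x = 0; dst_x < dst_w; ++dst_x) {
            float sum = 0.0f;
            for (int ky = 0; ky < knl_h; ++ky) {
                const int sy = dst_y * stride + ky * dilation - pad;
                if (sy < 0 || sy >= src_h) continue;
                for (int kx = 0; kx < knl_w; ++kx) {
                    const int sx = dst_x * stride + kx * dilation - pad;
                    if (sx < 0 || sx >= src_w) continue;
                    sum += knl[ky * knl_w + kx] * src[sy * src_w + sx];
                }
            }
            dst[dst_y * dst_w + dst_x] = sum;
        }
    }
}

int main(void) {
    // 4x4 input, 3x3 kernel, stride 1, pad 1, dilation 1 -> 4x4 output
    const int src_w = 4, src_h = 4, knl_w = 3, knl_h = 3;
    const int stride = 1, pad = 1, dilation = 1;
    const int dst_w = (src_w + 2*pad - dilation*(knl_w - 1) - 1)/stride + 1;
    const int dst_h = (src_h + 2*pad - dilation*(knl_h - 1) - 1)/stride + 1;
    float src[16], knl[9] = {0,0,0, 0,1,0, 0,0,0}, dst[16]; // identity kernel
    for (int i = 0; i < 16; ++i) src[i] = (float) i;
    ref_conv_2d_dw(src, knl, dst, src_w, src_h, knl_w, knl_h, dst_w, dst_h, stride, pad, dilation);
    printf("dst[5] = %.1f (expect 5.0)\n", dst[5]); // identity kernel reproduces the input
    return 0;
}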