@@ -6064,6 +6064,178 @@ void ggml_compute_forward_conv_transpose_2d(
60646064 }
60656065}
60666066
6067+ // ggml_compute_forward_conv_2d_dw
6068+
// Geometry and hyper-parameters for a depthwise 2D convolution,
// gathered once by the dispatcher and shared by both layout kernels
// (WHCN and CWHN).
struct ggml_conv_2d_dw_params {
    int64_t channels;    // number of channels (== number of depthwise groups)
    int64_t batch;       // batch size
    int64_t src_w;       // input width
    int64_t src_h;       // input height
    int64_t dst_w;       // output width
    int64_t dst_h;       // output height
    int64_t knl_w;       // kernel width
    int64_t knl_h;       // kernel height
    int stride_x;        // horizontal stride
    int stride_y;        // vertical stride
    int pad_x;           // horizontal zero-padding
    int pad_y;           // vertical zero-padding
    int dilation_x;      // horizontal dilation
    int dilation_y;      // vertical dilation
};
6085+
// Depthwise 2D convolution for the CWHN layout (channels contiguous in memory).
// Work is split across threads by output row (p.dst_h * p.batch rows total);
// the channel axis is vectorized in chunks of GGML_F32_EPR when GGML_SIMD is
// available, with a scalar loop covering the remainder.
static void ggml_compute_forward_conv_2d_dw_cwhn(
        const ggml_compute_params * params,
        const ggml_tensor * src,
        const ggml_tensor * kernel,
        ggml_tensor * dst,
        const ggml_conv_2d_dw_params & p) {

    const int64_t c = p.channels;
    const float * knl_data = (const float *)kernel->data;

    // Partition the (batch * dst_h) output rows evenly over the threads.
    const int64_t rows_total = p.dst_h * p.batch;
    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
    const int64_t row_start = params->ith * rows_per_thread;
    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

#ifdef GGML_SIMD
    // c_pkg_end is the largest multiple of the SIMD width not exceeding c;
    // channels [c_pkg_end, c) fall through to the scalar loop below.
    const int64_t pkg_size = GGML_F32_EPR;
    const int64_t pkg_count = c / pkg_size;
    const int64_t c_pkg_end = pkg_count * pkg_size;
#else
    const int64_t c_pkg_end = 0;
#endif

    for (int64_t row = row_start; row < row_end; ++row) {
        const int64_t dst_y = row % p.dst_h;
        // row / p.dst_h is the batch index; advance to that image's plane.
        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
            // Top-left source coordinate touched by this output pixel
            // (may be negative or past the edge due to padding).
            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;

#ifdef GGML_SIMD
            // Vectorized loop
            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
                    if (src_y < 0 || src_y >= p.src_h) {
                        continue; // zero padding: contributes nothing to the sum
                    }
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
                        if (src_x < 0 || src_x >= p.src_w) {
                            continue; // zero padding
                        }
                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
                        sum = GGML_F32_VEC_FMA(sum, k, s);
                    }
                }
                GGML_F32_VEC_STORE(dst_data + c_i, sum);
            }
#endif
            // Scalar loop (remainder channels, or all channels without GGML_SIMD)
            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
                float sum = 0.0f;
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
                    if (src_y < 0 || src_y >= p.src_h) {
                        continue;
                    }
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
                        if (src_x < 0 || src_x >= p.src_w) {
                            continue;
                        }
                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
                    }
                }
                dst_data[c_i] = sum;
            }
        }
    }
}
6161+
6162+ static void ggml_compute_forward_conv_2d_dw_whcn (
6163+ const ggml_compute_params * params,
6164+ const ggml_tensor * src,
6165+ const ggml_tensor * kernel,
6166+ ggml_tensor * dst,
6167+ const ggml_conv_2d_dw_params & p) {
6168+
6169+ const int64_t n = p.channels * p.batch ;
6170+ const int64_t per_thread = (n + params->nth - 1 ) / params->nth ;
6171+ const int64_t start = params->ith * per_thread;
6172+ const int64_t end = MIN (start + per_thread, n);
6173+
6174+ for (int64_t i = start; i < end; ++i) {
6175+ const float * knl_data = (const float *)kernel->data + (i % p.channels ) * p.knl_w * p.knl_h ;
6176+ const float * src_data = (const float *)src->data + i * p.src_w * p.src_h ;
6177+ float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h ;
6178+
6179+ for (int64_t dst_y = 0 ; dst_y < p.dst_h ; ++dst_y) {
6180+ for (int64_t dst_x = 0 ; dst_x < p.dst_w ; ++dst_x) {
6181+
6182+ float sum = 0 .0f ;
6183+ for (int64_t knl_y = 0 ; knl_y < p.knl_h ; ++knl_y) {
6184+ const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y ;
6185+ if (src_y < 0 || src_y >= p.src_h ) {
6186+ continue ;
6187+ }
6188+ for (int64_t knl_x = 0 ; knl_x < p.knl_w ; ++knl_x) {
6189+ const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x ;
6190+ if (src_x < 0 || src_x >= p.src_w ) {
6191+ continue ;
6192+ }
6193+ sum += knl_data[knl_y * p.knl_w + knl_x]
6194+ * src_data[src_y * p.src_w + src_x];
6195+ }
6196+ }
6197+ dst_data[dst_y * p.dst_w + dst_x] = sum;
6198+ }
6199+ }
6200+ }
6201+ }
6202+
6203+ void ggml_compute_forward_conv_2d_dw (
6204+ const ggml_compute_params * params,
6205+ ggml_tensor * dst) {
6206+
6207+ const ggml_tensor * kernel = dst->src [0 ];
6208+ const ggml_tensor * src = dst->src [1 ];
6209+ ggml_conv_2d_dw_params p;
6210+ p.channels = src->ne [2 ];
6211+ p.batch = src->ne [3 ];
6212+ p.src_w = src->ne [0 ];
6213+ p.src_h = src->ne [1 ];
6214+ p.dst_w = dst->ne [0 ];
6215+ p.dst_h = dst->ne [1 ];
6216+ p.knl_w = kernel->ne [0 ];
6217+ p.knl_h = kernel->ne [1 ];
6218+ p.stride_x = dst->op_params [0 ];
6219+ p.stride_y = dst->op_params [1 ];
6220+ p.pad_x = dst->op_params [2 ];
6221+ p.pad_y = dst->op_params [3 ];
6222+ p.dilation_x = dst->op_params [4 ];
6223+ p.dilation_y = dst->op_params [5 ];
6224+
6225+ GGML_ASSERT (kernel->ne [3 ] == p.channels );
6226+ GGML_ASSERT (dst->ne [3 ] == p.batch );
6227+
6228+ if (ggml_is_contiguous (src)) {
6229+ ggml_compute_forward_conv_2d_dw_whcn (params, src, kernel, dst, p);
6230+ } else if (ggml_is_contiguous_channels (src)) {
6231+ // kernel should also have channels most contiguous in memory
6232+ GGML_ASSERT (kernel->nb [0 ] >= kernel->nb [2 ] && kernel->nb [1 ] >= kernel->nb [0 ]);
6233+ ggml_compute_forward_conv_2d_dw_cwhn (params, src, kernel, dst, p);
6234+ } else {
6235+ GGML_ABORT (" non-contiguous memory layout not supported" );
6236+ }
6237+ }
6238+
60676239// ggml_compute_forward_pool_1d_sk_p0
60686240
60696241static void ggml_compute_forward_pool_1d_sk_p0 (
0 commit comments