@@ -6064,6 +6064,178 @@ void ggml_compute_forward_conv_transpose_2d(
     }
 }
 
+// ggml_compute_forward_conv_2d_dw
+
+struct ggml_conv_2d_dw_params {
+    int64_t channels;
+    int64_t batch;
+    int64_t src_w;
+    int64_t src_h;
+    int64_t dst_w;
+    int64_t dst_h;
+    int64_t knl_w;
+    int64_t knl_h;
+    int stride_x;
+    int stride_y;
+    int pad_x;
+    int pad_y;
+    int dilation_x;
+    int dilation_y;
+};
+
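+// CWHN layout: channels are the innermost (contiguous) dimension, so each output
+// pixel accumulates kernel[ky, kx, c] * src[sy, sx, c] over the kernel window with
+// c varying fastest; this is what lets the SIMD path vectorize across channels.
+// The dst_h * batch output rows are split evenly across threads.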
+static void ggml_compute_forward_conv_2d_dw_cwhn(
+        const ggml_compute_params * params,
+        const ggml_tensor * src,
+        const ggml_tensor * kernel,
+        ggml_tensor * dst,
+        const ggml_conv_2d_dw_params & p) {
+
+    const int64_t c = p.channels;
+    const float * knl_data = (const float *)kernel->data;
+
+    const int64_t rows_total = p.dst_h * p.batch;
+    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+    const int64_t row_start = params->ith * rows_per_thread;
+    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
+
+#ifdef GGML_SIMD
+    const int64_t pkg_size = GGML_F32_EPR;
+    const int64_t pkg_count = c / pkg_size;
+    const int64_t c_pkg_end = pkg_count * pkg_size;
+#else
+    const int64_t c_pkg_end = 0;
+#endif
+
+    for (int64_t row = row_start; row < row_end; ++row) {
+        const int64_t dst_y = row % p.dst_h;
+        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
+        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
+            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
+            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
+
+#ifdef GGML_SIMD
+            // Vectorized loop
+            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
+                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
+                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
+                        sum = GGML_F32_VEC_FMA(sum, k, s);
+                    }
+                }
+                GGML_F32_VEC_STORE(dst_data + c_i, sum);
+            }
+#endif
+            // Scalar loop
+            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
+                float sum = 0.0f;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
+                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
+                    }
+                }
+                dst_data[c_i] = sum;
+            }
+        }
+    }
+}
+
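+// WHCN layout: the default contiguous ggml layout (width innermost). Each thread
+// processes a slice of the channels * batch planes; within a plane the per-channel
+// kernel is reused for every output position, and taps that fall into the padding
+// region are simply skipped.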
+static void ggml_compute_forward_conv_2d_dw_whcn(
+        const ggml_compute_params * params,
+        const ggml_tensor * src,
+        const ggml_tensor * kernel,
+        ggml_tensor * dst,
+        const ggml_conv_2d_dw_params & p) {
+
+    const int64_t n = p.channels * p.batch;
+    const int64_t per_thread = (n + params->nth - 1) / params->nth;
+    const int64_t start = params->ith * per_thread;
+    const int64_t end = MIN(start + per_thread, n);
+
+    for (int64_t i = start; i < end; ++i) {
+        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
+        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
+        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
+
+        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
+            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
+
+                float sum = 0.0f;
+                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
+                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
+                    if (src_y < 0 || src_y >= p.src_h) {
+                        continue;
+                    }
+                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
+                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
+                        if (src_x < 0 || src_x >= p.src_w) {
+                            continue;
+                        }
+                        sum += knl_data[knl_y * p.knl_w + knl_x]
+                             * src_data[src_y * p.src_w + src_x];
+                    }
+                }
+                dst_data[dst_y * p.dst_w + dst_x] = sum;
+            }
+        }
+    }
+}
+
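+// Entry point: geometry is taken from the tensor shapes, stride/pad/dilation from
+// dst->op_params in the order the op constructor stored them. dst_w/dst_h are used
+// as-is; with the usual convolution size relation they would satisfy
+// dst_w = (src_w + 2*pad_x - dilation_x*(knl_w - 1) - 1) / stride_x + 1
+// (likewise for dst_h), but that is not re-checked here.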
+void ggml_compute_forward_conv_2d_dw(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * src = dst->src[1];
+    ggml_conv_2d_dw_params p;
+    p.channels = src->ne[2];
+    p.batch = src->ne[3];
+    p.src_w = src->ne[0];
+    p.src_h = src->ne[1];
+    p.dst_w = dst->ne[0];
+    p.dst_h = dst->ne[1];
+    p.knl_w = kernel->ne[0];
+    p.knl_h = kernel->ne[1];
+    p.stride_x = dst->op_params[0];
+    p.stride_y = dst->op_params[1];
+    p.pad_x = dst->op_params[2];
+    p.pad_y = dst->op_params[3];
+    p.dilation_x = dst->op_params[4];
+    p.dilation_y = dst->op_params[5];
+
+    GGML_ASSERT(kernel->ne[3] == p.channels);
+    GGML_ASSERT(dst->ne[3] == p.batch);
+
+    if (ggml_is_contiguous(src)) {
+        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
+    } else if (ggml_is_contiguous_channels(src)) {
+        // kernel should also have channels most contiguous in memory
+        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
+        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
+    } else {
+        GGML_ABORT("non-contiguous memory layout not supported");
+    }
+}
+
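+// Usage sketch (assumption: the graph-level constructor for this op is named
+// ggml_conv_2d_dw_direct and takes stride/pad/dilation in the same order as the
+// op_params read above; check ggml.h for the actual name and signature):
+//
+//   // src ne = [src_w, src_h, channels, batch], kernel ne[3] == channels
+//   ggml_tensor * out = ggml_conv_2d_dw_direct(ctx, kernel, src, 1, 1, 1, 1, 1, 1);
+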
 // ggml_compute_forward_pool_1d_sk_p0
 
 static void ggml_compute_forward_pool_1d_sk_p0(