@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
     ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void  * knl_data       = kernel->data;
+    float * dst_data       = (float *) dst->data;
+
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+    const int64_t knl_n_total       = knl_n_per_channel * c;
+    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *)element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start = params->ith * permute_per_thread;
+        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
+                const float value = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
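
The new kernel reads its output geometry from dst->ne and its stride/padding/dilation from op_params; it does not compute output sizes itself. For reference, here is a minimal sketch (not from this patch; the helper name conv_out_size is illustrative) of the standard convolution output-size relation those values are expected to satisfy per spatial dimension:

// Minimal sketch, not part of the patch: the output-size relation assumed for
// each spatial dimension. conv_out_size is an illustrative name, not a ggml API.
#include <cstdint>
#include <cassert>

static int64_t conv_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
    // the dilated kernel extent is d*(k - 1) + 1
    return (in + 2*p - d*(k - 1) - 1) / s + 1;
}

int main() {
    // e.g. a 16x16x16 volume with a 3x3x3 kernel, stride 1, padding 1, dilation 1
    // keeps its spatial size: (16 + 2*1 - 1*(3 - 1) - 1)/1 + 1 = 16
    assert(conv_out_size(16, 3, 1, 1, 1) == 16);
    // stride 2, no padding: (16 + 0 - 2 - 1)/2 + 1 = 7
    assert(conv_out_size(16, 3, 2, 0, 1) == 7);
    return 0;
}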
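Per batch of patches the implementation runs three phases: an im2col gather into the shared tmp workspace, one GEMM against the flattened kernel via ggml_call_mul_mat, and a scatter of the GEMM output rows back into dst. The following illustrative, self-contained sketch shows only the shape bookkeeping behind that GEMM, with made-up example sizes; it is not part of the patch and does not reproduce ggml_call_mul_mat:

// Illustrative only (example sizes are hypothetical): the matrix shapes behind
// the im2col + GEMM approach used above.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t c = 4,  oc = 8;                      // input / output channels
    const int64_t knl_w = 3, knl_h = 3, knl_d = 3;     // kernel size
    const int64_t dst_w = 16, dst_h = 16, dst_d = 16;  // output volume
    const int64_t n = 2;                               // batch size

    const int64_t knl_n_total = knl_w * knl_h * knl_d * c;  // length of one im2col row
    const int64_t patch_total = n * dst_w * dst_h * dst_d;  // one row per output voxel

    // conceptually: [patches x knl_n_total] * [knl_n_total x oc] -> [patches x oc],
    // which is why the scatter loop reads gemm_output[i * oc + ioc]
    printf("im2col: %lld rows of %lld, kernel: %lld x %lld, output: %lld rows of %lld\n",
           (long long) patch_total, (long long) knl_n_total,
           (long long) knl_n_total, (long long) oc,
           (long long) patch_total, (long long) oc);
    return 0;
}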