@@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
// Explicit instantiations of the kernel_rope_vision template. The host_name
// attribute sets the name under which the host looks each specialization up.
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
39893989
// Function type shared by all kernel_im2col<T> specializations; the explicit
// template instantiations use it to name the kernel's signature.
typedef void (im2col_t)(
        constant ggml_metal_kargs_im2col & args,          // op parameters
        device const float * x,                           // source data, always read as float
        device        char * dst,                         // raw destination buffer
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]);
3998+ 
// im2col kernel: copies elements of the float input `x` into `dst`
// (reinterpreted as an array of T), writing zero where the computed input
// coordinate falls outside [0, IH) x [0, IW).
//
// Launch layout as read by the kernel:
//   threadgroup position (x, y, z) -> iic, ioh, iow
//   threadgroups per grid (y, z)   -> OH, OW
//   thread position      (x, y, z) -> starting value of `in`, ikh, ikw
//   threads per group    (y, z)    -> KH, KW
//
// Each thread handles in = tpitg[0], tpitg[0] + ntg[0], ... while in < args.N,
// so the x-dimension of the threadgroup does not have to cover args.N.
//
// NOTE(review): the names suggest iic/ioh/iow are input channel / output row /
// output column and `in` is the batch index -- confirm against the host-side
// dispatch code.
template <typename T>
kernel void kernel_im2col(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
//    const int64_t IC = tgpg[0];
    const int64_t OH = tgpg[1];
    const int64_t OW = tgpg[2];

    const int64_t KH = ntg[1];
    const int64_t KW = ntg[2];

    const int64_t ikh = tpitg[1];
    const int64_t ikw = tpitg[2];

    const int64_t iic = tgpig[0];
    const int64_t ioh = tgpig[1];
    const int64_t iow = tgpig[2];

    // input coordinates addressed by this (output position, kernel tap) pair
    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;

    // Loop-invariant strides, hoisted out of the loops below. Widening ntg[0]
    // first keeps the whole product in 64-bit arithmetic.
    const int64_t step       = ntg[0];
    const int64_t stride_dst = step*args.CHW*OH*OW;

    int64_t in = tpitg[0];

    int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);

    device T * pdst = (device T *) (dst);

    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
        // outside the input: zero-fill for every `in` handled by this thread
        for (; in < args.N; in += step) {
            pdst[offset_dst] = 0.0f;

            offset_dst += stride_dst;
        }
    } else {
        const int64_t stride_src = step*args.ofs0;

        int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;

        for (; in < args.N; in += step) {
            pdst[offset_dst] = x[offset_src];

            offset_dst += stride_dst;
            offset_src += stride_src;
        }
    }
}
4050+ 
// Explicit instantiations of kernel_im2col for float and half destinations.
// The host_name attribute sets the name under which the host looks each one up.
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4053+ 
39904054//  TODO: obsolete -- remove
3991- // typedef void (im2col_t )(
4055+ // typedef void (im2col_ext_t )(
39924056//         constant ggml_metal_kargs_im2col & args,
39934057//         device const float * x,
39944058//         device        char * dst,
@@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker
39984062//         uint3   ntg[[threads_per_threadgroup]]);
39994063// 
40004064// template <typename T>
4001- // kernel void kernel_im2col (
4065+ // kernel void kernel_im2col_ext (
40024066//         constant ggml_metal_kargs_im2col & args,
40034067//         device const float * x,
40044068//         device        char * dst,
40054069//         uint3 tgpig[[threadgroup_position_in_grid]],
4006- //         uint3  tgpg[[threadgroups_per_grid]],
4070+ //         uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW 
40074071//         uint3 tpitg[[thread_position_in_threadgroup]],
4008- //         uint3   ntg[[threads_per_threadgroup]]) {
4009- // //    const int64_t IC = tgpg[0];
4010- //     const int64_t OH = tgpg[1];
4011- //     const int64_t OW = tgpg[2];
4072+ //         uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
4073+ //     const int64_t KHW = (int64_t)args.KHW;
40124074// 
4013- // //    const int64_t N  = ntg[0];
4014- //     const int64_t KH = ntg[1];
4015- //     const int64_t KW = ntg[2];
4075+ //     const int64_t d   = tgpig[0] / args.CHW;
4076+ //     const int64_t chw = tgpig[0] % args.CHW;
4077+ //     const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
4078+ //     const int64_t HW = tgpig[0] % KHW;
40164079// 
4017- //     const int64_t in  = tpitg[0];
4018- //     const int64_t ikh = tpitg[1];
4019- //     const int64_t ikw = tpitg[2];
4080+ //     const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4081+ //     if (tpitg_0 >= args.N) {
4082+ //         return;
4083+ //     }
40204084// 
4021- //     const int64_t iic = tgpig[0];
4022- //     const int64_t ioh = tgpig[1];
4023- //     const int64_t iow = tgpig[2];
4085+ //     const int64_t tpitg_1 = HW / args.KW;
4086+ //     const int64_t tpitg_2 = HW % args.KW;
40244087// 
4025- //     const int64_t iiw = iow* args.s0 + ikw* args.d0 - args.p0;
4026- //     const int64_t iih = ioh* args.s1 + ikh* args.d1 - args.p1;
4088+ //     const int64_t iiw = tgpig[2] *  args.s0 + tpitg_2 *  args.d0 - args.p0;
4089+ //     const int64_t iih = tgpig[1] *  args.s1 + tpitg_1 *  args.d1 - args.p1;
40274090// 
4028- //     const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
4091+ //     const int64_t offset_dst =
4092+ //         (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4093+ //         (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
40294094// 
40304095//     device T * pdst = (device T *) (dst);
40314096// 
40324097//     if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
40334098//         pdst[offset_dst] = 0.0f;
40344099//     } else {
4035- //         const int64_t offset_src = in* args.ofs0 + iic*args.ofs1 + iih* args.IW + iiw ;
4036- //         pdst[offset_dst] = x[offset_src];
4100+ //         const int64_t offset_src = tpitg_0 *  args.ofs0 + tgpig_0 *  args.ofs1 ;
4101+ //         pdst[offset_dst] = x[offset_src + iih * args.IW + iiw ];
40374102//     }
40384103// }
40394104// 
4040- // template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4041- // template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4042- 
4043- typedef  void  (im2col_ext_t )(
4044-         constant ggml_metal_kargs_im2col & args,
4045-         device const  float  * x,
4046-         device        char  * dst,
4047-         uint3 tgpig[[threadgroup_position_in_grid]],
4048-         uint3  tgpg[[threadgroups_per_grid]],
4049-         uint3 tpitg[[thread_position_in_threadgroup]],
4050-         uint3   ntg[[threads_per_threadgroup]]);
4051- 
4052- template  <typename  T>
4053- kernel void  kernel_im2col_ext (
4054-         constant ggml_metal_kargs_im2col & args,
4055-         device const  float  * x,
4056-         device        char  * dst,
4057-         uint3 tgpig[[threadgroup_position_in_grid]],
4058-         uint3  tgpg[[threadgroups_per_grid]],      //  tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4059-         uint3 tpitg[[thread_position_in_threadgroup]],
4060-         uint3   ntg[[threads_per_threadgroup]]) {  //  [M, 1, 1]
4061-     const  int64_t  KHW = (int64_t )args.KHW ;
4062- 
4063-     const  int64_t  d   = tgpig[0 ] / args.CHW ;
4064-     const  int64_t  chw = tgpig[0 ] % args.CHW ;
4065-     const  int64_t  tgpig_0 = chw / KHW;  //  0 ~ (IC - 1)
4066-     const  int64_t  HW = tgpig[0 ] % KHW;
4067- 
4068-     const  int64_t  tpitg_0 = (d * ntg[0 ]) + tpitg[0 ];
4069-     if  (tpitg_0 >= args.N ) {
4070-         return ;
4071-     }
4072- 
4073-     const  int64_t  tpitg_1 = HW / args.KW ;
4074-     const  int64_t  tpitg_2 = HW % args.KW ;
4075- 
4076-     const  int64_t  iiw = tgpig[2 ] * args.s0  + tpitg_2 * args.d0  - args.p0 ;
4077-     const  int64_t  iih = tgpig[1 ] * args.s1  + tpitg_1 * args.d1  - args.p1 ;
4078- 
4079-     const  int64_t  offset_dst =
4080-         (tpitg_0 * tgpg[1 ] * tgpg[2 ] + tgpig[1 ] * tgpg[2 ] + tgpig[2 ]) * args.CHW  +
4081-         (tgpig_0 * KHW + tpitg_1 * args.KW  + tpitg_2);
4082- 
4083-     device T * pdst = (device T *) (dst);
4084- 
4085-     if  (iih < 0  || iih >= args.IH  || iiw < 0  || iiw >= args.IW ) {
4086-         pdst[offset_dst] = 0 .0f ;
4087-     } else  {
4088-         const  int64_t  offset_src = tpitg_0 * args.ofs0  + tgpig_0 * args.ofs1 ;
4089-         pdst[offset_dst] = x[offset_src + iih * args.IW  + iiw];
4090-     }
4091- }
4092- 
4093- template  [[host_name(" kernel_im2col_ext_f32" im2col_ext_t  kernel_im2col_ext<float >;
4094- template  [[host_name(" kernel_im2col_ext_f16" im2col_ext_t  kernel_im2col_ext<half>;
4105+ // template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4106+ // template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
40954107
40964108typedef  void  (conv_transpose_1d_t )(
40974109        constant ggml_metal_kargs_conv_transpose_1d & args,
0 commit comments