@@ -4146,6 +4146,120 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
// template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
// template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
41484148
// Direct 2D convolution: each loop iteration produces one output element by
// accumulating, over all input channels, the products of the source samples
// and the kernel taps that fall inside the source image.
//
// TK is the element type read from `weights`; `src` is read as float and
// `dst` is written as float. All three buffers are addressed in bytes using
// the strides stored in `args` (nb00..nb03 for weights, nb10..nb13 for src,
// nb0..nb3 for dst).
//
// The flattened output index is decomposed as (n, oc, oh, ow) with `ow`
// varying fastest. Threads stride over the flattened output range by the
// total number of launched threads, so any grid/threadgroup shape covers all
// N*OC*OH*OW outputs.
//
// NOTE(review): s0/s1, p0/p1 and d0/d1 are presumably stride, padding and
// dilation; the code divides by d0/d1, so the host is assumed to pass
// values >= 1 - confirm against the caller.
template <typename TK>
kernel void kernel_conv_2d(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {

    const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
    const uint local_thread   = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;

    // The global indices are computed in 64 bits so that the starting index
    // cannot wrap when the launch contains more than 2^32 threads; the loop
    // stride and bound below are 64-bit as well.
    const uint64_t tg_index      = ((uint64_t) tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
    const uint64_t thread_index  = tg_index * threads_per_tg + local_thread;
    const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
    const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;

    for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
        // unflatten index -> (n, oc, oh, ow)
        uint64_t tmp = index;

        const int32_t ow = tmp % args.OW; tmp /= args.OW;
        const int32_t oh = tmp % args.OH; tmp /= args.OH;
        const int32_t oc = tmp % args.OC; tmp /= args.OC;
        const int32_t n  = tmp;

        float acc = 0.0f;

        // source coordinate touched by kernel tap (0, 0); may be negative
        const int32_t base_x = ow*args.s0 - args.p0;
        const int32_t base_y = oh*args.s1 - args.p1;

        // Clip the tap range [ky_start, ky_end) so that 0 <= base_y + ky*d1 < IH.
        // Taps outside the source are skipped and therefore contribute zero.
        int32_t ky_start = 0;
        if (base_y < 0) {
            // smallest ky with base_y + ky*d1 >= 0 (ceiling division)
            ky_start = (-base_y + args.d1 - 1)/args.d1;
        }
        int32_t ky_end = args.KH;
        const int32_t y_max = args.IH - 1 - base_y;
        if (y_max < 0) {
            // the whole window lies past the last row -> empty range
            ky_end = ky_start;
        } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
            ky_end = min(ky_end, y_max/args.d1 + 1);
        }

        // Same clipping along x: 0 <= base_x + kx*d0 < IW.
        int32_t kx_start = 0;
        if (base_x < 0) {
            kx_start = (-base_x + args.d0 - 1)/args.d0;
        }
        int32_t kx_end = args.KW;
        const int32_t x_max = args.IW - 1 - base_x;
        if (x_max < 0) {
            kx_end = kx_start;
        } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
            kx_end = min(kx_end, x_max/args.d0 + 1);
        }

        if (ky_start < ky_end && kx_start < kx_end) {
            const uint64_t src_base_n = (uint64_t) n  * args.nb13;
            const uint64_t w_base_oc  = (uint64_t) oc * args.nb03;

            for (int32_t ic = 0; ic < args.IC; ++ic) {
                const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
                const uint64_t w_base_ocic = w_base_oc  + (uint64_t) ic * args.nb02;

                for (int32_t ky = ky_start; ky < ky_end; ++ky) {
                    // iy is non-negative here thanks to the clipping above
                    const int32_t  iy           = base_y + ky*args.d1;
                    const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
                    const uint64_t w_base_row   = w_base_ocic + (uint64_t) ky * args.nb01;

                    for (int32_t kx = kx_start; kx < kx_end; ++kx) {
                        const int32_t  ix       = base_x + kx*args.d0;
                        const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
                        const uint64_t w_offs   = w_base_row   + (uint64_t) kx * args.nb00;

                        const float x = *(device const float *)(src + src_offs);
                        const float w = (float) (*(device const TK *)(weights + w_offs));

                        acc += x * w;
                    }
                }
            }
        }

        // The output is always written, even when no tap was in range (acc == 0).
        const uint64_t dst_offs =
            (uint64_t) n  * args.nb3 +
            (uint64_t) oc * args.nb2 +
            (uint64_t) oh * args.nb1 +
            (uint64_t) ow * args.nb0;

        *(device float *)(dst + dst_offs) = acc;
    }
}
4240+
// Explicit instantiation of kernel_conv_2d reading the weights as float.
// It is exported to the host under the exact name "kernel_conv_2d_f32_f32";
// the name must not contain surrounding whitespace or a lookup by that name fails.
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
4251+
// Explicit instantiation of kernel_conv_2d reading the weights as half.
// It is exported to the host under the exact name "kernel_conv_2d_f16_f32";
// the name must not contain surrounding whitespace or a lookup by that name fails.
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
4262+
41494263typedef void (conv_transpose_1d_t )(
41504264 constant ggml_metal_kargs_conv_transpose_1d & args,
41514265 device const float * src0,
0 commit comments