@@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
41794179 uint3 tgpig[[threadgroup_position_in_grid]],
41804180 uint3 tgpg[[threadgroups_per_grid]]);
41814181
4182+
// Function type for a 2D transposed-convolution kernel taking float src0/src1.
// NOTE(review): this parameter list (tgpig + tgpg, no threadgroup scratch buffer, no
// per-thread position) differs from kernel_conv_transpose_2d defined right below, and
// nothing in the visible part of the file references this typedef - confirm whether it
// is still needed or should be brought in line with the kernel signature.
4183+ typedef void (conv_transpose_2d_t )(
4184+ constant ggml_metal_kargs_conv_transpose_2d & args,
4185+ device const float * src0,
4186+ device const float * src1,
4187+ device char * dst,
4188+ uint3 tgpig[[threadgroup_position_in_grid]],
4189+ uint3 tgpg[[threadgroups_per_grid]]);
4190+
// 2D transposed convolution.
//
// Each threadgroup produces a single output element: (tgpig.x, tgpig.y, tgpig.z) is used as
// (out_x, out_y, out_c). Each thread inside the threadgroup handles one kernel tap
// (kw = tpitg.x, kh = tpitg.y), accumulates its contribution over all input channels, and the
// per-thread partial sums are then combined through threadgroup memory.
//
// args       - kernel arguments; this kernel reads IC, IH, IW, KH, KW, OC, s0 and nb0/nb1/nb2
// src0       - convolution kernel weights of type T, indexed as [in_c][out_c][kh][kw]
// src1       - float input, indexed as [in_c][in_y][in_x]
// dst        - output buffer, addressed in bytes through args.nb0/nb1/nb2
// shared_sum - threadgroup scratch; must hold at least ntg.x*ntg.y floats
//
// NOTE(review): the kernel indexing assumes tpitg.x < KW and tpitg.y < KH, i.e. presumably the
// launch uses a KW x KH threadgroup - confirm against the host-side dispatch.
template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const     T * src0,
        device const float * src1,
        device        char * dst,
        threadgroup  float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {

    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    float v = 0.0f;

    // The input position that feeds this output element through tap (kw, kh) does not depend on
    // the input channel, so it is resolved once here instead of on every loop iteration.
    // The same stride args.s0 is applied along both axes.
    const int64_t off_y = out_y - kh;
    const int64_t off_x = out_x - kw;

    // A tap contributes only when the offset is non-negative and an exact multiple of the stride.
    const bool tap_hits_input =
        off_y >= 0 && off_y % args.s0 == 0 &&
        off_x >= 0 && off_x % args.s0 == 0;

    if (tap_hits_input) {
        const int64_t in_y = off_y / args.s0;
        const int64_t in_x = off_x / args.s0;

        if (in_y < args.IH && in_x < args.IW) {
            // channel-invariant parts of the two flat indices
            const int64_t input_base  = (args.IW) * in_y + in_x;
            const int64_t kernel_base = (args.KH * args.KW) * out_c + (args.KW) * kh + kw;

            for (int64_t in_c = 0; in_c < args.IC; in_c++) {
                const int64_t input_idx  = (args.IW * args.IH) * in_c + input_base;
                const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + kernel_base;

                v += (float)src0[kernel_idx] * src1[input_idx];
            }
        }
    }

    // Every thread publishes its partial sum (0 for taps that miss the input). No thread may
    // return before the barrier, otherwise the reduction below would read unwritten slots.
    const uint tid = tpitg.y * ntg.x + tpitg.x;
    shared_sum[tid] = v;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 adds up the partial sums in index order and writes the output element.
    if (tid == 0) {
        float total = 0.0f;
        const uint num_threads = ntg.x * ntg.y;
        for (uint i = 0; i < num_threads; i++) {
            total += shared_sum[i];
        }

        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y*args.nb1 + out_c*args.nb2);
        dst_ptr[0] = total;
    }
}
4250+
// Explicit instantiation of kernel_conv_transpose_2d for float kernel weights (src0).
// The host_name must not contain leading whitespace: it is the exact name the function is
// exported under, so a stray blank would change the symbol the host has to look up.
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        threadgroup  float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
4261+
// Explicit instantiation of kernel_conv_transpose_2d for half-precision kernel weights (src0);
// the input (src1) stays float.
// The host_name must not contain leading whitespace: it is the exact name the function is
// exported under, so a stray blank would change the symbol the host has to look up.
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const  half * src0,
        device const float * src1,
        device        char * dst,
        threadgroup  float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
4272+
41824273kernel void kernel_upscale_f32 (
41834274 constant ggml_metal_kargs_upscale & args,
41844275 device const char * src0,
0 commit comments