@@ -4147,45 +4147,57 @@ kernel void kernel_conv_transpose_2d(
41474147 device const float * src1,
41484148 device char * dst,
41494149 uint3 tgpig[[threadgroup_position_in_grid]],
4150- uint3 tgpg[[threadgroups_per_grid]]) {
4150+ uint3 tpitg[[thread_position_in_threadgroup]],
4151+ uint3 ntg[[threads_per_threadgroup]]) {
41514152
41524153 const int64_t out_x = tgpig[0 ];
41534154 const int64_t out_y = tgpig[1 ];
41544155 const int64_t out_c = tgpig[2 ];
41554156
4157+ const int64_t kw = tpitg[0 ];
4158+ const int64_t kh = tpitg[1 ];
4159+
41564160 float v = 0 .0f ;
41574161
41584162 for (int64_t in_c = 0 ; in_c < args.IC ; in_c++) {
4159- for (int64_t kh = 0 ; kh < args.KH ; kh++) {
4163+ int64_t in_y = out_y - kh;
4164+
4165+ if (in_y < 0 || in_y % args.s0 ) continue ;
41604166
4161- int64_t in_y = out_y - kh ;
4167+ in_y /= args. s0 ;
41624168
4163- if (in_y < 0 || in_y % args.s0 ) continue ;
4169+ if (in_y >= args.IH ) continue ;
41644170
4165- in_y /= args. s0 ;
4171+ int64_t in_x = out_x - kw ;
41664172
4167- if (in_y >= args.IH ) continue ;
4173+ if (in_x < 0 || in_x % args.s0 ) continue ;
41684174
4169- for (int64_t kw = 0 ; kw < args.KW ; kw++) {
4170- int64_t in_x = out_x - kw;
4175+ in_x /= args.s0 ;
41714176
4172- if (in_x < 0 || in_x % args.s0 ) continue ;
4177+ if (in_x >= args.IW ) continue ;
41734178
4174- in_x /= args.s0 ;
4179+ const int64_t input_idx = (args.IW * args.IH ) * in_c + (args.IW ) * in_y + in_x;
4180+ const int64_t kernel_idx = (args.KH * args.KW * args.OC ) * in_c + (args.KH * args.KW ) * out_c + (args.KW ) * kh + kw;
41754181
4176- if (in_x >= args.IW ) continue ;
4182+ v += (float )src0[kernel_idx] * src1[input_idx];
4183+ }
41774184
4178- const int64_t input_idx = (args.IW * args.IH ) * in_c + (args.IW ) * in_y + in_x;
4179- const int64_t kernel_idx = (args.KH * args.KW * args.OC ) * in_c + (args.KH * args.KW ) * out_c + (args.KW ) * kh + kw;
4185+ threadgroup float shared_sum[64 ];
4186+ const uint tid = tpitg.y * ntg.x + tpitg.x ;
4187+ shared_sum[tid] = v;
41804188
4181- v += ( float )src0[kernel_idx] * src1[input_idx] ;
4189+ threadgroup_barrier (mem_flags::mem_threadgroup) ;
41824190
4183- }
4191+ if (tid == 0 ) {
4192+ float total = 0 .0f ;
4193+ const uint num_threads = ntg.x * ntg.y ;
4194+ for (uint i = 0 ; i < num_threads; i++) {
4195+ total += shared_sum[i];
41844196 }
4185- }
4186- device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2 );
41874197
4188- dst_ptr[0 ] = v;
4198+ device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2 );
4199+ dst_ptr[0 ] = total;
4200+ }
41894201}
41904202
41914203template [[host_name(" kernel_conv_transpose_2d_f32_f32" )]]
@@ -4195,7 +4207,8 @@ kernel void kernel_conv_transpose_2d<float>(
41954207 device const float * src1,
41964208 device char * dst,
41974209 uint3 tgpig[[threadgroup_position_in_grid]],
4198- uint3 tgpg[[threadgroups_per_grid]]);
4210+ uint3 tpitg[[thread_position_in_threadgroup]],
4211+ uint3 ntg[[threads_per_threadgroup]]);
41994212
42004213template [[host_name(" kernel_conv_transpose_2d_f16_f32" )]]
42014214kernel void kernel_conv_transpose_2d<half>(
@@ -4204,7 +4217,8 @@ kernel void kernel_conv_transpose_2d<half>(
42044217 device const float * src1,
42054218 device char * dst,
42064219 uint3 tgpig[[threadgroup_position_in_grid]],
4207- uint3 tgpg[[threadgroups_per_grid]]);
4220+ uint3 tpitg[[thread_position_in_threadgroup]],
4221+ uint3 ntg[[threads_per_threadgroup]]);
42084222
42094223kernel void kernel_upscale_f32 (
42104224 constant ggml_metal_kargs_upscale & args,
0 commit comments