@@ -4146,6 +4146,7 @@ kernel void kernel_conv_transpose_2d(
  device const T * src0,
  device const float * src1,
  device char * dst,
+ threadgroup float * shared_sum [[threadgroup(0)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint3 tpitg[[thread_position_in_threadgroup]],
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -4182,7 +4183,6 @@ kernel void kernel_conv_transpose_2d(
  v += (float)src0[kernel_idx] * src1[input_idx];
  }
 
- threadgroup float shared_sum[64];
  const uint tid = tpitg.y * ntg.x + tpitg.x;
  shared_sum[tid] = v;
 
@@ -4206,6 +4206,7 @@ kernel void kernel_conv_transpose_2d<float>(
  device const float * src0,
  device const float * src1,
  device char * dst,
+ threadgroup float * shared_sum [[threadgroup(0)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint3 tpitg[[thread_position_in_threadgroup]],
  uint3 ntg[[threads_per_threadgroup]]);
@@ -4216,6 +4217,7 @@ kernel void kernel_conv_transpose_2d<half>(
  device const half * src0,
  device const float * src1,
  device char * dst,
+ threadgroup float * shared_sum [[threadgroup(0)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint3 tpitg[[thread_position_in_threadgroup]],
  uint3 ntg[[threads_per_threadgroup]]);
0 commit comments