@@ -1236,53 +1236,68 @@ kernel void kernel_ssm_scan_f32(
12361236}
12371237
// kernel_norm: per-row normalization. The row mean is subtracted and the result
// is scaled by 1/sqrt(variance + eps). The row is selected by tgpig.
//
// NOTE(review): this span looks like a rendered diff rather than plain source:
// lines tagged `-` carry the previous implementation, lines tagged `+` carry its
// replacement, and every line is prefixed with the diff's own line numbers.
// The comments added below describe the `+` lines unless they say otherwise.
//
// Replacement version, as visible here:
//   args      - parameter block; the body reads args.nb01, args.ne00,
//               args.ne00_4 and args.eps
//   src0      - input bytes; the row starts at src0 + tgpig*args.nb01
//   dst       - output bytes, viewed as float4 and offset by tgpig*args.ne00_4
//   shmem_f32 - threadgroup scratch used to exchange per-simdgroup partial sums
// Presumably args.ne00_4 == args.ne00/4, i.e. rows are a multiple of 4 floats
// long - TODO confirm against the host-side setup.
12381238kernel void kernel_norm (
1239- device const void * src0,
1240- device float * dst,
1241- constant int64_t & ne00,
1242- constant uint64_t & nb01,
1243- constant float & eps,
1244- threadgroup float * sum [[threadgroup(0 )]],
1245- uint tgpig[[threadgroup_position_in_grid]],
1246- uint tpitg[[thread_position_in_threadgroup]],
1247- uint ntg[[threads_per_threadgroup]]) {
1248- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
1249- // MEAN
1250- // parallel sum
1251- sum[tpitg] = 0 .0f ;
1252- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1253- sum[tpitg] += x[i00];
1239+ constant ggml_metal_kargs_norm & args,
1240+ device const char * src0,
1241+ device char * dst,
1242+ threadgroup float * shmem_f32 [[threadgroup(0 )]],
1243+ uint tgpig[[threadgroup_position_in_grid]],
1244+ ushort tpitg[[thread_position_in_threadgroup]],
1245+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1246+ ushort tiisg[[thread_index_in_simdgroup]],
1247+ ushort ntg[[threads_per_threadgroup]]) {
// Simdgroup 0 clears one scratch slot per lane. Slots that no simdgroup
// overwrites at the `shmem_f32[sgitg] = sumf` stores below keep this zero, so
// they add nothing when every lane's slot is summed later.
1248+ if (sgitg == 0 ) {
1249+ shmem_f32[tiisg] = 0 .0f ;
12541250 }
1255- // reduce
1251+
// View the input row as float4 vectors; src0 is a byte pointer, so args.nb01
// is applied as a byte offset.
1252+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01 );
1253+
1254+ float4 sumf4 (0 .0f );
1255+
1256+ float sumf = 0 .0f ;
1257+
// Each thread accumulates every ntg-th float4 of the row, starting at tpitg.
1258+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
1259+ sumf4 += x[i00];
1260+ }
// Collapse the four vector components, then reduce across the simdgroup.
1261+ sumf = sumf4[0 ] + sumf4[1 ] + sumf4[2 ] + sumf4[3 ];
1262+ sumf = simd_sum (sumf);
1263+
// Sits between the scratch clear above and the per-simdgroup stores below;
// presumably there so the clear cannot overwrite a stored partial sum.
12561264 threadgroup_barrier (mem_flags::mem_threadgroup);
// Removed version: every thread kept its own sum[tpitg] slot and this
// halving-stride loop folded the slots into sum[0].
1257- for (uint i = ntg/2 ; i > 0 ; i /= 2 ) {
1258- if (tpitg < i) {
1259- sum[tpitg] += sum[tpitg + i];
1260- }
1261- threadgroup_barrier (mem_flags::mem_threadgroup);
1265+
// Lane 0 of each simdgroup publishes that simdgroup's partial sum.
1266+ if (tiisg == 0 ) {
1267+ shmem_f32[sgitg] = sumf;
12621268 }
1263- const float mean = sum[0 ] / ne00;
12641269
1265- // recenter and VARIANCE
12661270 threadgroup_barrier (mem_flags::mem_threadgroup);
1267- device float * y = dst + tgpig*ne00;
1268- sum[tpitg] = 0 .0f ;
1269- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1271+
// Every thread reads the slot matching its lane index and reduces again, so
// each thread ends up with the sum over the published partials.
// NOTE(review): this relies on the simdgroup count not exceeding the number of
// lanes in a simdgroup - confirm against the dispatch configuration.
1272+ sumf = shmem_f32[tiisg];
1273+ sumf = simd_sum (sumf);
1274+
1275+ const float mean = sumf/args.ne00 ;
1276+
// Output row viewed as float4. The offset is counted in float4 elements
// (args.ne00_4 per row) instead of a byte stride like the input's nb01, which
// presumably means dst rows are densely packed - TODO confirm.
1277+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4 ;
1278+
// Second pass: store the centered values and accumulate their squares.
1279+ sumf = 0 .0f ;
1280+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
12701281 y[i00] = x[i00] - mean;
1271- sum[tpitg] += y[i00] * y[i00];
1282+ sumf += dot ( y[i00], y[i00]) ;
12721283 }
1284+ sumf = simd_sum (sumf);
12731285
1274- // reduce
// Separates the scratch reads used for the mean from the stores below that
// reuse the same slots for the squared sums.
12751286 threadgroup_barrier (mem_flags::mem_threadgroup);
1276- for (uint i = ntg/2 ; i > 0 ; i /= 2 ) {
1277- if (tpitg < i) {
1278- sum[tpitg] += sum[tpitg + i];
1279- }
1280- threadgroup_barrier (mem_flags::mem_threadgroup);
1287+
// Same publish / barrier / re-reduce sequence as for the mean.
1288+ if (tiisg == 0 ) {
1289+ shmem_f32[sgitg] = sumf;
12811290 }
1282- const float variance = sum[0 ] / ne00;
12831291
1284- const float scale = 1 .0f /sqrt (variance + eps);
1285- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1292+ threadgroup_barrier (mem_flags::mem_threadgroup);
1293+
1294+ sumf = shmem_f32[tiisg];
1295+ sumf = simd_sum (sumf);
1296+
// Sum of squared centered values divided by the row length args.ne00.
1297+ const float variance = sumf/args.ne00 ;
1298+
// Third pass: rescale the centered values already stored in y, in place.
1299+ const float scale = 1 .0f /sqrt (variance + args.eps );
1300+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
12861301 y[i00] = y[i00] * scale;
12871302 }
12881303}
0 commit comments