@@ -1288,50 +1288,45 @@ kernel void kernel_norm(
12881288}
12891289
12901290kernel void kernel_rms_norm (
1291- device const void * src0 ,
1292- device float * dst ,
1293- constant int64_t & ne00 ,
1294- constant uint64_t & nb01 ,
1295- constant float & eps ,
1296- threadgroup float * buf [[threadgroup( 0 ) ]],
1297- uint tgpig[[threadgroup_position_in_grid ]],
1298- uint tpitg[[thread_position_in_threadgroup ]],
1299- uint sgitg[[simdgroup_index_in_threadgroup]],
1300- uint tiisg[[thread_index_in_simdgroup]],
1301- uint ntg[[threads_per_threadgroup]]) {
1302- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
1291+ constant ggml_metal_kargs_rms_norm & args ,
1292+ device const char * src0 ,
1293+ device char * dst ,
1294+ threadgroup float * shmem_f32 [[threadgroup( 0 )]] ,
1295+ uint tgpig[[threadgroup_position_in_grid]] ,
1296+ ushort tpitg[[thread_position_in_threadgroup ]],
1297+ ushort sgitg[[simdgroup_index_in_threadgroup ]],
1298+ ushort tiisg[[thread_index_in_simdgroup ]],
1299+ ushort ntg[[threads_per_threadgroup]]) {
1300+ if (sgitg == 0 ) {
1301+ shmem_f32[tiisg] = 0 . 0f ;
1302+ }
13031303
1304- float4 sumf = 0 ;
1305- float all_sum = 0 ;
1304+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01 );
1305+
1306+ float sumf = 0 .0f ;
13061307
13071308 // parallel sum
1308- for (int i00 = tpitg; i00 < ne00/ 4 ; i00 += ntg) {
1309- sumf += x[i00] * x[i00];
1309+ for (int i00 = tpitg; i00 < args. ne00_4 ; i00 += ntg) {
1310+ sumf += dot ( x[i00], x[i00]) ;
13101311 }
1311- all_sum = sumf[0 ] + sumf[1 ] + sumf[2 ] + sumf[3 ];
1312- all_sum = simd_sum (all_sum);
1313- if (ntg > N_SIMDWIDTH) {
1314- if (sgitg == 0 ) {
1315- buf[tiisg] = 0 .0f ;
1316- }
1312+ sumf = simd_sum (sumf);
13171313
1318- threadgroup_barrier (mem_flags::mem_threadgroup);
1314+ threadgroup_barrier (mem_flags::mem_threadgroup);
13191315
1320- if (tiisg == 0 ) {
1321- buf [sgitg] = all_sum ;
1322- }
1316+ if (tiisg == 0 ) {
1317+ shmem_f32 [sgitg] = sumf ;
1318+ }
13231319
1324- threadgroup_barrier (mem_flags::mem_threadgroup);
1320+ threadgroup_barrier (mem_flags::mem_threadgroup);
13251321
1326- all_sum = buf[tiisg];
1327- all_sum = simd_sum (all_sum);
1328- }
1322+ sumf = shmem_f32[tiisg];
1323+ sumf = simd_sum (sumf);
13291324
1330- const float mean = all_sum/ ne00;
1331- const float scale = 1 .0f /sqrt (mean + eps);
1325+ const float mean = sumf/args. ne00 ;
1326+ const float scale = 1 .0f /sqrt (mean + args. eps );
13321327
1333- device float4 * y = (device float4 *) ( dst + tgpig*ne00) ;
1334- for (int i00 = tpitg; i00 < ne00/ 4 ; i00 += ntg) {
1328+ device float4 * y = (device float4 *) dst + tgpig*args. ne00_4 ;
1329+ for (int i00 = tpitg; i00 < args. ne00_4 ; i00 += ntg) {
13351330 y[i00] = x[i00] * scale;
13361331 }
13371332}
0 commit comments