@@ -1293,50 +1293,45 @@ kernel void kernel_norm(
12931293}
12941294
12951295kernel void kernel_rms_norm (
1296- device const void * src0 ,
1297- device float * dst ,
1298- constant int64_t & ne00 ,
1299- constant uint64_t & nb01 ,
1300- constant float & eps ,
1301- threadgroup float * buf [[threadgroup( 0 ) ]],
1302- uint tgpig[[threadgroup_position_in_grid ]],
1303- uint tpitg[[thread_position_in_threadgroup ]],
1304- uint sgitg[[simdgroup_index_in_threadgroup]],
1305- uint tiisg[[thread_index_in_simdgroup]],
1306- uint ntg[[threads_per_threadgroup]]) {
1307- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
1296+ constant ggml_metal_kargs_rms_norm & args ,
1297+ device const char * src0 ,
1298+ device char * dst ,
1299+ threadgroup float * shmem_f32 [[threadgroup( 0 )]] ,
1300+ uint tgpig[[threadgroup_position_in_grid]] ,
1301+ ushort tpitg[[thread_position_in_threadgroup ]],
1302+ ushort sgitg[[simdgroup_index_in_threadgroup ]],
1303+ ushort tiisg[[thread_index_in_simdgroup ]],
1304+ ushort ntg[[threads_per_threadgroup]]) {
1305+ if (sgitg == 0 ) {
1306+ shmem_f32[tiisg] = 0 . 0f ;
1307+ }
13081308
1309- float4 sumf = 0 ;
1310- float all_sum = 0 ;
1309+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01 );
1310+
1311+ float sumf = 0 .0f ;
13111312
13121313 // parallel sum
1313- for (int i00 = tpitg; i00 < ne00/ 4 ; i00 += ntg) {
1314- sumf += x[i00] * x[i00];
1314+ for (int i00 = tpitg; i00 < args. ne00_4 ; i00 += ntg) {
1315+ sumf += dot ( x[i00], x[i00]) ;
13151316 }
1316- all_sum = sumf[0 ] + sumf[1 ] + sumf[2 ] + sumf[3 ];
1317- all_sum = simd_sum (all_sum);
1318- if (ntg > N_SIMDWIDTH) {
1319- if (sgitg == 0 ) {
1320- buf[tiisg] = 0 .0f ;
1321- }
1317+ sumf = simd_sum (sumf);
13221318
1323- threadgroup_barrier (mem_flags::mem_threadgroup);
1319+ threadgroup_barrier (mem_flags::mem_threadgroup);
13241320
1325- if (tiisg == 0 ) {
1326- buf [sgitg] = all_sum ;
1327- }
1321+ if (tiisg == 0 ) {
1322+ shmem_f32 [sgitg] = sumf ;
1323+ }
13281324
1329- threadgroup_barrier (mem_flags::mem_threadgroup);
1325+ threadgroup_barrier (mem_flags::mem_threadgroup);
13301326
1331- all_sum = buf[tiisg];
1332- all_sum = simd_sum (all_sum);
1333- }
1327+ sumf = shmem_f32[tiisg];
1328+ sumf = simd_sum (sumf);
13341329
1335- const float mean = all_sum/ ne00;
1336- const float scale = 1 .0f /sqrt (mean + eps);
1330+ const float mean = sumf/args. ne00 ;
1331+ const float scale = 1 .0f /sqrt (mean + args. eps );
13371332
1338- device float4 * y = (device float4 *) ( dst + tgpig*ne00) ;
1339- for (int i00 = tpitg; i00 < ne00/ 4 ; i00 += ntg) {
1333+ device float4 * y = (device float4 *) dst + tgpig*args. ne00_4 ;
1334+ for (int i00 = tpitg; i00 < args. ne00_4 ; i00 += ntg) {
13401335 y[i00] = x[i00] * scale;
13411336 }
13421337}
0 commit comments