@@ -993,6 +993,7 @@ kernel void kernel_neg(
993993 dst[tpig] = -src0[tpig];
994994}
995995
996+ template <bool norm>
996997kernel void kernel_sum_rows (
997998 constant ggml_metal_kargs_sum_rows & args,
998999 device const float * src0,
@@ -1038,59 +1039,14 @@ kernel void kernel_sum_rows(
10381039 sumf = simd_sum (sumf);
10391040
10401041 if (tpitg.x == 0 ) {
1041- dst_row[0 ] = sumf;
1042+ dst_row[0 ] = norm ? sumf / args. ne00 : sumf;
10421043 }
10431044}
10441045
1045- // TODO: deduplicate with sum_rows
1046- kernel void kernel_mean (
1047- constant ggml_metal_kargs_sum_rows & args,
1048- device const float * src0,
1049- device float * dst,
1050- threadgroup float * shmem_f32 [[threadgroup(0 )]],
1051- uint3 tgpig[[threadgroup_position_in_grid]],
1052- ushort3 tpitg[[thread_position_in_threadgroup]],
1053- ushort sgitg[[simdgroup_index_in_threadgroup]],
1054- ushort tiisg[[thread_index_in_simdgroup]],
1055- ushort3 ntg[[threads_per_threadgroup]]) {
1056- int64_t i3 = tgpig.z ;
1057- int64_t i2 = tgpig.y ;
1058- int64_t i1 = tgpig.x ;
1059-
1060- if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01 ) {
1061- return ;
1062- }
1063-
1064- if (sgitg == 0 ) {
1065- shmem_f32[tiisg] = 0 .0f ;
1066- }
1067-
1068- device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03 );
1069- device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3 );
1046+ typedef decltype (kernel_sum_rows<false >) kernel_sum_rows_t;
10701047
1071- float sumf = 0 ;
1072-
1073- for (int64_t i0 = tpitg.x ; i0 < args.ne00 ; i0 += ntg.x ) {
1074- sumf += src_row[i0];
1075- }
1076-
1077- sumf = simd_sum (sumf);
1078-
1079- threadgroup_barrier (mem_flags::mem_threadgroup);
1080-
1081- if (tiisg == 0 ) {
1082- shmem_f32[sgitg] = sumf;
1083- }
1084-
1085- threadgroup_barrier (mem_flags::mem_threadgroup);
1086-
1087- sumf = shmem_f32[tiisg];
1088- sumf = simd_sum (sumf);
1089-
1090- if (tpitg.x == 0 ) {
1091- dst_row[0 ] = sumf/args.ne00 ;
1092- }
1093- }
1048+ template [[host_name(" kernel_sum_rows" )]] kernel kernel_sum_rows_t kernel_sum_rows<false >;
1049+ template [[host_name(" kernel_mean" )]] kernel kernel_sum_rows_t kernel_sum_rows<true >;
10941050
10951051template <typename T>
10961052kernel void kernel_soft_max (
0 commit comments