Skip to content

Commit fddb313

Browse files
committed
cont : dedup implementation
ggml-ci
1 parent b65122c commit fddb313

File tree

1 file changed

+5
-49
lines changed

1 file changed

+5
-49
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 5 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -993,6 +993,7 @@ kernel void kernel_neg(
993993
dst[tpig] = -src0[tpig];
994994
}
995995

996+
template <bool norm>
996997
kernel void kernel_sum_rows(
997998
constant ggml_metal_kargs_sum_rows & args,
998999
device const float * src0,
@@ -1038,59 +1039,14 @@ kernel void kernel_sum_rows(
10381039
sumf = simd_sum(sumf);
10391040

10401041
if (tpitg.x == 0) {
1041-
dst_row[0] = sumf;
1042+
dst_row[0] = norm ? sumf / args.ne00 : sumf;
10421043
}
10431044
}
10441045

1045-
// TODO: deduplicate with sum_rows
1046-
kernel void kernel_mean(
1047-
constant ggml_metal_kargs_sum_rows & args,
1048-
device const float * src0,
1049-
device float * dst,
1050-
threadgroup float * shmem_f32 [[threadgroup(0)]],
1051-
uint3 tgpig[[threadgroup_position_in_grid]],
1052-
ushort3 tpitg[[thread_position_in_threadgroup]],
1053-
ushort sgitg[[simdgroup_index_in_threadgroup]],
1054-
ushort tiisg[[thread_index_in_simdgroup]],
1055-
ushort3 ntg[[threads_per_threadgroup]]) {
1056-
int64_t i3 = tgpig.z;
1057-
int64_t i2 = tgpig.y;
1058-
int64_t i1 = tgpig.x;
1059-
1060-
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
1061-
return;
1062-
}
1063-
1064-
if (sgitg == 0) {
1065-
shmem_f32[tiisg] = 0.0f;
1066-
}
1067-
1068-
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1069-
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1046+
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
10701047

1071-
float sumf = 0;
1072-
1073-
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1074-
sumf += src_row[i0];
1075-
}
1076-
1077-
sumf = simd_sum(sumf);
1078-
1079-
threadgroup_barrier(mem_flags::mem_threadgroup);
1080-
1081-
if (tiisg == 0) {
1082-
shmem_f32[sgitg] = sumf;
1083-
}
1084-
1085-
threadgroup_barrier(mem_flags::mem_threadgroup);
1086-
1087-
sumf = shmem_f32[tiisg];
1088-
sumf = simd_sum(sumf);
1089-
1090-
if (tpitg.x == 0) {
1091-
dst_row[0] = sumf/args.ne00;
1092-
}
1093-
}
1048+
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
1049+
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
10941050

10951051
template<typename T>
10961052
kernel void kernel_soft_max(

0 commit comments

Comments
 (0)