Skip to content

Commit 16bc059

Browse files
committed
feat: Parallel sum in SSM_CONV
Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 80545ef commit 16bc059

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2909,7 +2909,26 @@ static bool ggml_metal_encode_node(
29092909
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
29102910
[encoder setBytes:&args length:sizeof(args) atIndex:3];
29112911

2912-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2912+
const int64_t d_state = ne10;
2913+
2914+
// One shared memory bucket for each simd group in the threadgroup
2915+
if (d_state >= 32) {
2916+
const int64_t shmem_size = d_state / 32;
2917+
2918+
// The final simd_sum won't work if the number of simd groups is
2919+
// larger than the size of a single simd group. If this case is
2920+
// hit at some point, the logic in the second simd_sum could be
2921+
// expanded to handle this with one more sequential simd_sum to
2922+
// collapse simd group sums another time.
2923+
GGML_ASSERT(shmem_size <= 32);
2924+
2925+
// One thread pre element in d_state
2926+
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
2927+
2928+
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
2929+
}
2930+
2931+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
29132932
} break;
29142933
case GGML_OP_SSM_SCAN:
29152934
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32(
16631663
device const void * src0,
16641664
device const void * src1,
16651665
device float * dst,
1666+
threadgroup float * shared [[threadgroup(0)]],
16661667
constant ggml_metal_kargs_ssm_conv & args,
1667-
uint3 tgpig[[threadgroup_position_in_grid]],
1668-
uint3 tpitg[[thread_position_in_threadgroup]],
1669-
uint3 ntg[[threads_per_threadgroup]]) {
1668+
uint3 tgpig[[threadgroup_position_in_grid]],
1669+
uint3 tpitg[[thread_position_in_threadgroup]],
1670+
ushort sgitg[[simdgroup_index_in_threadgroup]],
1671+
ushort tiisg[[thread_index_in_simdgroup]],
1672+
ushort sgptg[[simdgroups_per_threadgroup]],
1673+
uint3 tgpg[[threadgroups_per_grid]]) {
1674+
1675+
const int64_t i0 = tpitg.x;
16701676
const int64_t ir = tgpig.x;
16711677
const int64_t i2 = tgpig.y;
16721678
const int64_t i3 = tgpig.z;
@@ -1681,13 +1687,31 @@ kernel void kernel_ssm_conv_f32(
16811687
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
16821688
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
16831689

1684-
float sumf = 0.0f;
1690+
float sumf = s[i0] * c[i0];
16851691

1686-
for (int64_t i0 = 0; i0 < nc; ++i0) {
1687-
sumf += s[i0] * c[i0];
1688-
}
1692+
// Parallel sum: first sum over threads in simd group, then sum over simd
1693+
// group sums
1694+
sumf = simd_sum(sumf);
16891695

1690-
x[0] = sumf;
1696+
// If multiple simd groups per threadgroup, sum over simd group sums
1697+
if (sgptg > 1) {
1698+
if (tiisg == 0) {
1699+
shared[sgitg] = sumf;
1700+
}
1701+
threadgroup_barrier(mem_flags::mem_threadgroup);
1702+
sumf = 0.0f;
1703+
if (sgitg == 0) {
1704+
if (tiisg < sgptg) {
1705+
sumf = shared[tiisg];
1706+
}
1707+
sumf = simd_sum(sumf);
1708+
if (tiisg == 0) {
1709+
x[0] = sumf;
1710+
}
1711+
}
1712+
} else if (tiisg == 0) {
1713+
x[0] = sumf;
1714+
}
16911715
}
16921716

16931717
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part

0 commit comments

Comments
 (0)