@@ -1767,7 +1767,8 @@ kernel void kernel_ssm_scan_f32_group(
17671767        uint3  tpitg[[thread_position_in_threadgroup]],
17681768        ushort sgitg[[simdgroup_index_in_threadgroup]],
17691769        ushort tiisg[[thread_index_in_simdgroup]],
1770-         uint3    ntg[[threads_per_threadgroup]]) {
1770+         ushort sgptg[[simdgroups_per_threadgroup]],
1771+         uint3   tgpg[[threadgroups_per_grid]]) {
17711772
17721773    const  int64_t  i1 = tgpig.x ;
17731774    const  int64_t  ir = tgpig.y ; //  current head
@@ -1802,29 +1803,42 @@ kernel void kernel_ssm_scan_f32_group(
18021803        const  float  x_dt = x[0 ] * dt_soft_plus;
18031804        const  float  dA = exp (dt_soft_plus * A[0 ]);
18041805
1805-         threadgroup_barrier (mem_flags::mem_threadgroup);
1806- 
1807-         float  sumf = 0 .0f ;
1808- 
18091806        const  int64_t  i = tpitg.x  + i1*nc;
18101807        const  float  state = (s0[i] * dA) + (B[tpitg.x ] * x_dt);
1811-         sumf += state * C[tpitg.x ];
18121808        s[i] = state;
18131809
1814-         sumf = simd_sum (sumf);
1810+         //  Parallel sum: This relies on the fact that this kernel will be
1811+         //  dispatched with each threadgroup having (d_state, 1, 1) threads which
1812+         //  are subdivided into `sgptg` SIMD groups. The goal is to
1813+         //  compute y = sum({state * C[i] for i in range(d_state)}).
1814+         //  To parallelize this effectively, we first use simd_sum over each SIMD
1815+         //  group to compute the sum of each SIMD group, then place the result in
1816+         //  the SIMD group's indexed bucket in the shared memory. We then sum
1817+         //  over the individual group sums to compute the final sum.
18151818
1816-         threadgroup_barrier (mem_flags::mem_threadgroup);
1819+         //  Computed for each thread
1820+         float  sumf = state * C[tpitg.x ];
18171821
1818-         //  Use the shared buffer to hold the sum of each simd group
1822+         //  Sum the threads in the simd group => simd sum
1823+         sumf = simd_sum (sumf);
1824+ 
1825+         //  Once per simd group, place the group sum into the shared buffer
18191826        if  (tiisg == 0 ) {
18201827            shared[sgitg] = sumf;
18211828        }
18221829
1830+         //  Wait for all threads in the threadgroup to reach this point. This
1831+         //  ensures that all elements of the shared buffer are populated with the
1832+         //  sum of the individual simd groups.
18231833        threadgroup_barrier (mem_flags::mem_threadgroup);
18241834
1825-         //  Sum the simd buckets
1826-         sumf = shared[tiisg];
1827-         sumf = simd_sum (sumf);
1835+         //  Sum the simd buckets => threadgroup sum
1836+         sumf = 0 .0f ;
1837+         for  (int64_t  i0 = 0 ; i0 < sgptg; ++i0) {
1838+             sumf += shared[i0];
1839+         }
1840+ 
1841+         threadgroup_barrier (mem_flags::mem_threadgroup);
18281842
18291843        y[0 ] = sumf;
18301844
0 commit comments