@@ -1700,10 +1700,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device       float * dst,
+        threadgroup  float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1718,37 +1724,85 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;

-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;

     device const int32_t * ids = (device const int32_t *) src6;

-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;

-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: This relies on the fact that this kernel will be
+        // dispatched with each threadgroup having (d_state, 1, 1) threads
+        // which are subdivided into `sgptg` SIMD groups. The goal is to
+        // compute y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum over each
+        // SIMD group to compute the sum of that SIMD group, then place the
+        // result in the SIMD group's indexed bucket in the shared memory.
+        // We then sum over the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);

-        y[0] = sumf;
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point.
+            // This ensures that all elements of the shared buffer are
+            // populated with the sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // For simd group 0 at indices < num simd groups, extract the
+            // shared simd sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
@@ -1770,6 +1824,7 @@ kernel void kernel_ssm_scan_f32_group(
         ushort sgptg[[simdgroups_per_threadgroup]],
         uint3   tgpg[[threadgroups_per_grid]]) {

+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1790,7 +1845,7 @@ kernel void kernel_ssm_scan_f32_group(

     device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
     device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
-    const int64_t i = tpitg.x + i1*nc;
+    const int64_t i = i0 + i1*nc;
     float s0 = s0_buff[i];
     float s  = s_buff[i];

@@ -1812,7 +1867,7 @@ kernel void kernel_ssm_scan_f32_group(
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);

-        const float state = (s0 * dA) + (B[tpitg.x] * x_dt);
+        const float state = (s0 * dA) + (B[i0] * x_dt);
         s = state;

         // Parallel sum: This relies on the fact that this kernel will be
@@ -1825,7 +1880,7 @@ kernel void kernel_ssm_scan_f32_group(
         // over the individual group sums to compute the final sum.

         // Computed for each thread
-        float sumf = state * C[tpitg.x];
+        float sumf = state * C[i0];

         // Sum the threads in the simd group => simd sum
         sumf = simd_sum(sumf);
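
For readers unfamiliar with the two-level reduction this diff introduces, below is a minimal standalone sketch of the same pattern in isolation. The kernel name, buffer bindings, and input layout are illustrative only (not part of this change); it assumes one value per thread and, as in the kernel above, that the number of SIMD groups per threadgroup does not exceed the SIMD width, so SIMD group 0 can reduce all the staged partials in one simd_sum.

#include <metal_stdlib>
using namespace metal;

// Sketch: dot-product-style reduction across one threadgroup.
// Level 1: simd_sum reduces within each SIMD group.
// Level 2: one partial per SIMD group is staged in threadgroup memory,
//          then SIMD group 0 reduces those partials to the final value.
kernel void reduce_sum_sketch(          // hypothetical kernel, not in ggml
        device const float * a      [[buffer(0)]],
        device const float * b      [[buffer(1)]],
        device       float * out    [[buffer(2)]],
        threadgroup  float * shared [[threadgroup(0)]],
        uint   tpitg [[thread_position_in_threadgroup]],
        ushort tiisg [[thread_index_in_simdgroup]],
        ushort sgitg [[simdgroup_index_in_threadgroup]],
        ushort sgptg [[simdgroups_per_threadgroup]]) {
    // per-thread partial (mirrors state * C[i0] in the kernel above)
    float sumf = a[tpitg] * b[tpitg];

    // level 1: sum across the threads of this SIMD group
    sumf = simd_sum(sumf);

    // level 2: stage one partial per SIMD group, then reduce in group 0
    if (tiisg == 0) {
        shared[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        sumf = tiisg < sgptg ? shared[tiisg] : 0.0f;
        sumf = simd_sum(sumf);
        if (tiisg == 0) {
            out[0] = sumf;
        }
    }
}

One host-side consequence of the new `threadgroup float * shared [[threadgroup(0)]]` argument: the dispatching code must now reserve threadgroup memory for that slot (on Metal this is done with MTLComputeCommandEncoder's setThreadgroupMemoryLength:atIndex:, sized here for at least one float per SIMD group), which is presumably handled in the accompanying ggml-metal host changes not shown in this diff.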