@@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
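Both scan kernels receive the same signature change: the threads-per-threadgroup argument `ntg` is replaced by the SIMD-group builtins `sgitg`, `tiisg`, and `sgptg`, a `shared` threadgroup buffer is added for the cross-simdgroup reduction, and each thread now owns one state column via `i0 = tpitg.x`. The kernel therefore expects to be dispatched with `d_state` threads per threadgroup. A rough sizing sketch, using illustrative numbers that are assumptions rather than values from this diff:

    // Hypothetical dispatch geometry, assuming d_state = 128 (a common
    // Mamba-2 value) and the Apple GPU SIMD width of 32:
    //   threads per threadgroup : (d_state, 1, 1) = (128, 1, 1)
    //   sgptg                   : 128 / 32 = 4 simdgroups per threadgroup
    //   `shared`                : sgptg * sizeof(float) = 16 bytes at [[threadgroup(0)]]
    //   threadgroups per grid   : (n_head, n_seqs, 1) for this kernel, since
    //                             tgpig.x/.y select the head and sequence and
    //                             i1 is fixed to 0 here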
@@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
 
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: this relies on the fact that the kernel is dispatched
+        // with each threadgroup having (d_state, 1, 1) threads, which are
+        // subdivided into `sgptg` SIMD groups. The goal is to compute
+        // y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum to reduce
+        // within each SIMD group, then place each group's partial sum in that
+        // SIMD group's indexed bucket in shared memory, and finally sum over
+        // the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
 
-        y[0] = sumf;
+        // Sum across the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point.
+            // This ensures that every element of the shared buffer is
+            // populated with the sum of its simd group.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // In simd group 0, the threads with index < sgptg read the
+            // partial sums back out and reduce them with one more simd_sum
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
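Before the second kernel's hunks, it is worth spelling out what the rewritten loop computes. Below is a scalar model of the per-(row, head, seq) recurrence, mirroring the removed serial loop; the kernel instead assigns one `i0` per thread and replaces the `sumf` accumulation with the simd reduction above. This is a minimal sketch with a hypothetical name and flattened strides, not code from the tree:

    #include <metal_stdlib>
    using namespace metal;

    // Hypothetical scalar reference: s0 and s hold nc = d_state floats; x and
    // dt hold one value per token; B and C hold nc values per token, flattened
    // as [i2*nc + i0]; y receives one value per token.
    inline void ssm_scan_ref(device const float * s0, device const float * A,
                             device const float * x,  device const float * dt,
                             device const float * B,  device const float * C,
                             device float * s, device float * y,
                             int nc, int n_t) {
        for (int i2 = 0; i2 < n_t; ++i2) {
            // softplus(dt), with the same large-input shortcut as the kernel
            const float dt_sp = dt[i2] <= 20.0f ? log(1.0f + exp(dt[i2])) : dt[i2];
            const float x_dt  = x[i2] * dt_sp;
            float sumf = 0.0f;
            for (int i0 = 0; i0 < nc; ++i0) {
                const float state = s0[i0] * exp(dt_sp * A[i0]) + B[i2*nc + i0] * x_dt;
                sumf += state * C[i2*nc + i0];
                s[i0] = state;
            }
            y[i2] = sumf;
            s0 = s; // recurse: the next token continues from the updated state
        }
    }

Hoisting `x_block`/`dt_block`/`B_block`/`C_block`/`y_block` out of the token loop in the diff does not change this computation; it only leaves the per-token stride additions inside the loop.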
@@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
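The group variant shares the reduction machinery with kernel_ssm_scan_f32 but differs in two ways visible in these hunks: the grid gains a row dimension (`i1 = tgpig.x`, so threadgroups cover (dim, nh, ns) instead of (nh, ns) with `i1 = 0`), and `A` is a single scalar per head ({1, nh} rather than {d_state, nh}), which lets the decay factor `dA` be hoisted out of the per-element state update. Writing dt' for `dt_soft_plus`, the per-element updates compare as:

    // kernel_ssm_scan_f32       : s[i0] = s0[i0] * exp(dt' * A[i0]) + B[i0] * (x * dt')
    // kernel_ssm_scan_f32_group : s[i0] = s0[i0] * dA + B[i0] * (x * dt'),
    //                             with dA = exp(dt' * A[0]) shared by all i0
    // both kernels              : y = sum of s[i0] * C[i0] over i0 < d_state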
@@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;
 
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: this relies on the fact that the kernel is dispatched
+        // with each threadgroup having (d_state, 1, 1) threads, which are
+        // subdivided into `sgptg` SIMD groups. The goal is to compute
+        // y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this effectively, we first use simd_sum to reduce
+        // within each SIMD group, then place each group's partial sum in that
+        // SIMD group's indexed bucket in shared memory, and finally sum over
+        // the individual group sums to compute the final sum.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum across the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
         }
 
-        y[0] = sumf;
+        // Wait for all threads in the threadgroup to reach this point.
+        // This ensures that every element of the shared buffer is
+        // populated with the sum of its simd group.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // In simd group 0, the threads with index < sgptg read the
+        // partial sums back out and reduce them with one more simd_sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
 
 kernel void kernel_rwkv_wkv6_f32(
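The reduction that both kernels now inline follows a standard two-level pattern that can be written as a standalone helper. A minimal sketch, assuming one thread per element and a `shared` buffer of at least `sgptg` floats; `threadgroup_sum` is a hypothetical name, not a ggml function, and only simdgroup 0 returns the full sum, matching the kernels above. The trailing barrier is a conservative addition so the buffer can be reused across calls in a loop, as these kernels reuse it across tokens.

    #include <metal_stdlib>
    using namespace metal;

    // Hypothetical helper: two-level sum over all threads in a threadgroup.
    // Level 1 reduces within each simdgroup in registers; level 2 funnels the
    // per-simdgroup partials through threadgroup memory and reduces them in
    // simdgroup 0. Only simdgroup 0 holds the final sum on return.
    inline float threadgroup_sum(float v,
                                 threadgroup float * shared,
                                 ushort tiisg,   // thread_index_in_simdgroup
                                 ushort sgitg,   // simdgroup_index_in_threadgroup
                                 ushort sgptg) { // simdgroups_per_threadgroup
        v = simd_sum(v);          // level 1: register-only reduction
        if (sgptg == 1) {
            return v;             // a single simdgroup needs no shared memory
        }
        if (tiisg == 0) {
            shared[sgitg] = v;    // one partial per simdgroup
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        v = (sgitg == 0 && tiisg < sgptg) ? shared[tiisg] : 0.0f;
        // make `shared` safe to overwrite on the next call (e.g. next token)
        threadgroup_barrier(mem_flags::mem_threadgroup);
        return simd_sum(v);       // level 2: reduce the partials
    }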