@@ -2075,105 +2075,88 @@ kernel void kernel_ssm_scan_f32(
20752075 device const void * src6,
20762076 device float * dst,
20772077 threadgroup float * shared [[threadgroup(0 )]],
2078- uint3 tgpig[[threadgroup_position_in_grid]],
2079- uint3 tpitg[[thread_position_in_threadgroup]],
2080- ushort sgitg[[simdgroup_index_in_threadgroup]],
2081- ushort tiisg[[thread_index_in_simdgroup]],
2082- ushort sgptg[[simdgroups_per_threadgroup]],
2083- uint3 tgpg[[threadgroups_per_grid]]) {
2078+ uint3 tgpig[[threadgroup_position_in_grid]],
2079+ ushort3 tpitg[[thread_position_in_threadgroup]],
2080+ ushort sgitg[[simdgroup_index_in_threadgroup]],
2081+ ushort tiisg[[thread_index_in_simdgroup]],
2082+ ushort sgptg[[simdgroups_per_threadgroup]],
2083+ uint3 tgpg[[threadgroups_per_grid]]) {
2084+ constexpr short NW = N_SIMDWIDTH;
20842085
20852086 shared[tpitg.x ] = 0 .0f ;
20862087
2087- const int64_t i0 = tpitg.x ;
2088- const int64_t i1 = tgpig.x ;
2089- const int64_t ir = tgpig.y ; // current head
2090- const int64_t i3 = tgpig.z ; // current seq
2091-
2092- const uint64_t nb00 = sizeof (float );
2093- const uint64_t nb10 = sizeof (float );
2094- const uint64_t nb20 = sizeof (float );
2088+ const int32_t i0 = tpitg.x ;
2089+ const int32_t i1 = tgpig.x ;
2090+ const int32_t ir = tgpig.y ; // current head
2091+ const int32_t i3 = tgpig.z ; // current seq
20952092
2096- const int64_t nc = args.d_state ;
2097- const int64_t nr = args.d_inner ;
2098- const int64_t nh = args.n_head ;
2099- const int64_t ng = args.n_group ;
2100- const int64_t n_t = args.n_seq_tokens ;
2093+ const int32_t nc = args.d_state ;
2094+ const int32_t nr = args.d_inner ;
2095+ const int32_t nh = args.n_head ;
2096+ const int32_t ng = args.n_group ;
2097+ const int32_t n_t = args.n_seq_tokens ;
21012098
2102- const int64_t s_off = args.s_off ;
2099+ const int32_t s_off = args.s_off ;
21032100
21042101 device const int32_t * ids = (device const int32_t *) src6;
21052102
21062103 device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
21072104 device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2108- const int64_t i = i0 + i1*nc;
2109- const int64_t g = ir / (nh / ng); // repeat_interleave
2105+
2106+ const int32_t i = i0 + i1*nc;
2107+ const int32_t g = ir / (nh / ng); // repeat_interleave
2108+
21102109 float s0 = s0_buff[i];
21112110 float s = 0 .0f ;
21122111
21132112 device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {ne30, nh}
2113+
21142114 const float A0 = A[i0%args.ne30 ];
21152115
2116- device const char * x_block = ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13 );
2117- device const char * dt_block = ((device const char *) src2 + ir*nb20 + i3*args.nb22 );
2118- device const char * B_block = ((device const char *) src4 + g*args.nb41 + i3*args.nb43 );
2119- device const char * C_block = ((device const char *) src5 + g*args.nb51 + i3*args.nb53 );
2120- device char * y_block = ((device char *) dst + (i1 + ir*(nr) + i3*(n_t *nh*nr))*nb00);
2116+ device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13 ); // {dim, nh, nt, ns}
2117+ device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22 ); // {nh, nt, ns}
2118+ device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43 ); // {d_state, ng, nt, ns}
2119+ device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53 ); // {d_state, ng, nt, ns}
21212120
2122- threadgroup_barrier (mem_flags::mem_threadgroup);
2121+ device float * y = dst + (i1 + ir*(nr) + i3*( n_t *nh*nr)); // {dim, nh, nt, ns}
21232122
2124- for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
2125- device const float * x = (device const float *) (x_block + i2*args.nb12 ); // {dim, nh, nt, ns}
2126- device const float * dt = (device const float *) (dt_block + i2*args.nb21 ); // {nh, nt, ns}
2127- device const float * B = (device const float *) (B_block + i2*args.nb42 ); // {d_state, ng, nt, ns}
2128- device const float * C = (device const float *) (C_block + i2*args.nb52 ); // {d_state, ng, nt, ns}
2129- device float * y = (device float *) (y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2123+ for (int i2 = 0 ; i2 < n_t ; i2 += sgptg) {
2124+ threadgroup_barrier (mem_flags::mem_threadgroup);
21302125
2131- const float dt0 = dt[0 ];
2132- const float dt_soft_plus = dt0 <= 20 .0f ? log (1 .0f + exp (dt0)) : dt0;
2133- const float x_dt = x[0 ] * dt_soft_plus;
2134- const float dA = exp (dt_soft_plus * A0);
2126+ for (int t = 0 ; t < sgptg && i2 + t < n_t ; t++) {
2127+ const float dt0 = dt[0 ];
2128+ const float dtsp = dt0 <= 20 .0f ? log (1 .0f + exp (dt0)) : dt0;
2129+ const float x_dt = x[0 ] * dtsp;
2130+ const float dA = exp (dtsp * A0);
21352131
2136- s = (s0 * dA) + (B[i0] * x_dt);
2132+ s = (s0 * dA) + (B[i0] * x_dt);
21372133
2138- // Parallel sum: This relies on the fact that this kernel will be
2139- // dispatched with each threadgroup having (d_state, 1, 1) threads which
2140- // are subdivided into SIMD groups of size `sgptg`. The goal is to
2141- // compute y = sum({state * C[i] for i in range(d_state)}).
2142- // To parallelize this effectively, we first use simd_sum over each SIMD
2143- // group to compute the sum of each SIMD group, then place the result in
2144- // the SIMD group's indexed bucket in the shared memory. We then sum
2145- // over the individual group sums to compute the final sum.
2134+ const float sumf = simd_sum (s * C[i0]);
21462135
2147- // Computed for each thread
2148- float sumf = s * C[i0];
2136+ if (tiisg == 0 ) {
2137+ shared[t*NW + sgitg] = sumf;
2138+ }
21492139
2150- // Sum the threads in the simd group => simd sum
2151- sumf = simd_sum (sumf) ;
2140+ // recurse
2141+ s0 = s ;
21522142
2153- // Once per simd group, place the group sum into the shared buffer
2154- if (tiisg == 0 ) {
2155- shared[sgitg] = sumf;
2143+ x += args.ns12 ;
2144+ dt += args.ns21 ;
2145+ B += args.ns42 ;
2146+ C += args.ns52 ;
21562147 }
21572148
2158- // Wait for all threads in the threadgroup to reach this point. This
2159- // ensures that all elements of the shared buffer are populated with the
2160- // sum of the individual simd groups.
21612149 threadgroup_barrier (mem_flags::mem_threadgroup);
21622150
2163- // For simd group 0 at indices < num simd groups, extract the shared
2164- // simd sum
2165- if (sgitg == 0 ) {
2166- sumf = simd_sum (shared[tiisg]);
2167- if (tiisg == 0 ) {
2168- y[0 ] = sumf;
2169- }
2151+ const float sumf = simd_sum (shared[sgitg*NW + tiisg]);
2152+
2153+ if (tiisg == 0 && i2 + sgitg < n_t ) {
2154+ y[sgitg*nh*nr] = sumf;
21702155 }
21712156
2172- // recurse
2173- s0 = s;
2157+ y += sgptg*nh*nr;
21742158 }
21752159
2176- // Assign the final state to the output buffer
21772160 s_buff[i] = s;
21782161}
21792162
0 commit comments