@@ -2063,7 +2063,7 @@ kernel void kernel_ssm_conv_f32_f32_4(
20632063 x[0 ] = sumf;
20642064}
20652065
2066- // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
2066+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
20672067kernel void kernel_ssm_scan_f32 (
20682068 constant ggml_metal_kargs_ssm_scan & args,
20692069 device const void * src0,
@@ -2081,120 +2081,8 @@ kernel void kernel_ssm_scan_f32(
20812081 ushort tiisg[[thread_index_in_simdgroup]],
20822082 ushort sgptg[[simdgroups_per_threadgroup]],
20832083 uint3 tgpg[[threadgroups_per_grid]]) {
2084- const int64_t i0 = tpitg.x ;
2085- const int64_t i1 = 0 ;
2086- const int64_t ir = tgpig.x ; // current head
2087- const int64_t i3 = tgpig.y ; // current seq
2088-
2089- const uint64_t nb00 = sizeof (float );
2090- const uint64_t nb10 = sizeof (float );
2091- const uint64_t nb20 = sizeof (float );
2092-
2093- const int64_t nc = args.d_state ;
2094- const int64_t nr = args.d_inner ;
2095- const int64_t nh = args.n_head ;
2096- const int64_t ng = args.n_group ;
2097- const int64_t n_t = args.n_seq_tokens ;
2098-
2099- const int64_t s_off = args.s_off ;
21002084
2101- device const int32_t * ids = (device const int32_t *) src6;
2102-
2103- device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
2104- device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2105- const int64_t i = i0 + i1*nc;
2106- const int64_t g = ir / (nh / ng); // repeat_interleave
2107- float s0 = s0_buff[i];
2108- float s = 0 .0f ;
2109-
2110- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 );
2111- device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13 );
2112- device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22 );
2113- device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43 );
2114- device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53 );
2115- device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t *nh*nr))*nb00);
2116-
2117- for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
2118- device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12 ); // {dim, nh, nt, ns}
2119- device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21 ); // {nh, nt, ns}
2120- device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42 ); // {d_state, ng, nt, ns}
2121- device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52 ); // {d_state, ng, nt, ns}
2122- device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2123-
2124- const float dt_soft_plus = dt[0 ] <= 20 .0f ? log (1 .0f + exp (dt[0 ])) : dt[0 ];
2125- const float x_dt = x[0 ] * dt_soft_plus;
2126-
2127- s = (s0 * exp (dt_soft_plus * A[i0])) + (B[i0] * x_dt);
2128-
2129- // Parallel sum: This relies on the fact that this kernel will be
2130- // dispatched with each threadgroup having (d_state, 1, 1) threads which
2131- // are subdivided into SIMD groups of size `sgptg`. The goal is to
2132- // compute y = sum({state * C[i] for i in range(d_state)}).
2133- // To parallelize this effectively, we first use simd_sum over each SIMD
2134- // group to compute the sum of each SIMD group, then place the result in
2135- // the SIMD group's indexed bucket in the shared memory. We then sum
2136- // over the individual group sums to compute the final sum.
2137-
2138- // Computed for each thread
2139- float sumf = s * C[i0];
2140-
2141- // Sum the threads in the simd group => simd sum
2142- sumf = simd_sum (sumf);
2143-
2144- if (sgptg > 1 ) {
2145-
2146- // Once per simd group, place the group sum into the shared buffer
2147- if (tiisg == 0 ) {
2148- shared[sgitg] = sumf;
2149- }
2150-
2151- // Wait for all threads in the threadgroup to reach this point. This
2152- // ensures that all elements of the shared buffer are populated with the
2153- // sum of the individual simd groups.
2154- threadgroup_barrier (mem_flags::mem_threadgroup);
2155-
2156- // For simd group 0 at indices < num simd groups, extract the shared
2157- // simd sum
2158- sumf = 0 .0f ;
2159- if (sgitg == 0 ) {
2160- if (tiisg < sgptg) {
2161- sumf = shared[tiisg];
2162- }
2163- sumf = simd_sum (sumf);
2164- if (tiisg == 0 ) {
2165- y[0 ] = sumf;
2166- }
2167- }
2168- } else if (tiisg == 0 ) {
2169- y[0 ] = sumf;
2170- }
2171-
2172- // recurse
2173- s0 = s;
2174- }
2175-
2176- // Assign the final state to the output buffer
2177- s_buff[i] = s;
2178- }
2179-
2180- // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
2181- kernel void kernel_ssm_scan_group_f32 (
2182- constant ggml_metal_kargs_ssm_scan & args,
2183- device const void * src0,
2184- device const void * src1,
2185- device const void * src2,
2186- device const void * src3,
2187- device const void * src4,
2188- device const void * src5,
2189- device const void * src6,
2190- device float * dst,
2191- threadgroup float * shared [[threadgroup(0 )]],
2192- uint3 tgpig[[threadgroup_position_in_grid]],
2193- uint3 tpitg[[thread_position_in_threadgroup]],
2194- ushort sgitg[[simdgroup_index_in_threadgroup]],
2195- ushort tiisg[[thread_index_in_simdgroup]],
2196- ushort sgptg[[simdgroups_per_threadgroup]],
2197- uint3 tgpg[[threadgroups_per_grid]]) {
2085+ shared[tpitg.x ] = 0 .0f ;
21982086
21992087 const int64_t i0 = tpitg.x ;
22002088 const int64_t i1 = tgpig.x ;
@@ -2222,23 +2110,28 @@ kernel void kernel_ssm_scan_group_f32(
22222110 float s0 = s0_buff[i];
22232111 float s = 0 .0f ;
22242112
2225- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {1, nh}
2226- device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13 );
2227- device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22 );
2228- device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43 );
2229- device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53 );
2230- device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t *nh*nr))*nb00);
2113+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {ne30, nh}
2114+ const float A0 = A[i0%args.ne30 ];
22312115
2232- for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
2233- device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12 ); // {dim, nh, nt, ns}
2234- device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21 ); // {nh, nt, ns}
2235- device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42 ); // {d_state, ng, nt, ns}
2236- device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52 ); // {d_state, ng, nt, ns}
2237- device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2116+ device const char * x_block = ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13 );
2117+ device const char * dt_block = ((device const char *) src2 + ir*nb20 + i3*args.nb22 );
2118+ device const char * B_block = ((device const char *) src4 + g*args.nb41 + i3*args.nb43 );
2119+ device const char * C_block = ((device const char *) src5 + g*args.nb51 + i3*args.nb53 );
2120+ device char * y_block = ((device char *) dst + (i1 + ir*(nr) + i3*(n_t *nh*nr))*nb00);
2121+
2122+ threadgroup_barrier (mem_flags::mem_threadgroup);
22382123
2239- const float dt_soft_plus = dt[0 ] <= 20 .0f ? log (1 .0f + exp (dt[0 ])) : dt[0 ];
2124+ for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
2125+ device const float * x = (device const float *) (x_block + i2*args.nb12 ); // {dim, nh, nt, ns}
2126+ device const float * dt = (device const float *) (dt_block + i2*args.nb21 ); // {nh, nt, ns}
2127+ device const float * B = (device const float *) (B_block + i2*args.nb42 ); // {d_state, ng, nt, ns}
2128+ device const float * C = (device const float *) (C_block + i2*args.nb52 ); // {d_state, ng, nt, ns}
2129+ device float * y = (device float *) (y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2130+
2131+ const float dt0 = dt[0 ];
2132+ const float dt_soft_plus = dt0 <= 20 .0f ? log (1 .0f + exp (dt0)) : dt0;
22402133 const float x_dt = x[0 ] * dt_soft_plus;
2241- const float dA = exp (dt_soft_plus * A[ 0 ] );
2134+ const float dA = exp (dt_soft_plus * A0 );
22422135
22432136 s = (s0 * dA) + (B[i0] * x_dt);
22442137
@@ -2269,12 +2162,8 @@ kernel void kernel_ssm_scan_group_f32(
22692162
22702163 // For simd group 0 at indices < num simd groups, extract the shared
22712164 // simd sum
2272- sumf = 0 .0f ;
22732165 if (sgitg == 0 ) {
2274- if (tiisg < sgptg) {
2275- sumf = shared[tiisg];
2276- }
2277- sumf = simd_sum (sumf);
2166+ sumf = simd_sum (shared[tiisg]);
22782167 if (tiisg == 0 ) {
22792168 y[0 ] = sumf;
22802169 }
0 commit comments