@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
 __global__ void __launch_bounds__(splitD, 2)
     ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
-                 const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
-                 const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                 const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                 float * __restrict__ dst, const int64_t L) {
-    GGML_UNUSED(src1_nb0);
-    GGML_UNUSED(src2_nb0);
+                 const int32_t * __restrict__ src6, float * __restrict__ dst,
+                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+                 const int64_t s_off, const int64_t d_inner, const int64_t L) {
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x;  // split along B
-    const int bidy = blockIdx.y;  // split along D
+    const int bidx = blockIdx.x;  // split along B (sequences)
+    const int bidy = blockIdx.y;  // split along D (d_inner)
     const int tid  = threadIdx.x;
     const int wid  = tid / 32;
     const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
     float * smem_A  = smem;
     float * smem_s0 = smem_A + splitD * stride_sA;
 
-    const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
-    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
+    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
     const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb2));
-    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb2));
-    float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
-    float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
+    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
+    float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
+    float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
 
-    const int stride_s0 = src0_nb1 / sizeof(float);
-    const int stride_x  = src1_nb1 / sizeof(float);
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_x  = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
     const int stride_A  = src3_nb1 / sizeof(float);
-    const int stride_B  = src4_nb1 / sizeof(float);
-    const int stride_C  = src5_nb1 / sizeof(float);
+    const int stride_B  = src4_nb2 / sizeof(float);
+    const int stride_C  = src5_nb2 / sizeof(float);
     const int stride_s  = stride_s0;
-    const int stride_y  = stride_x;
+    const int stride_y  = d_inner;
 
     // can N not be 16? for example 32?
     if (N == 16) {
@@ -84,24 +83,157 @@ __global__ void __launch_bounds__(splitD, 2)
     }
 }
 
+// assumes as many threads as d_state
+template <int splitH, int d_state>
+__global__ void __launch_bounds__(d_state, 1)
+    ssm_scan_f32_group(
+        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+        const int32_t * __restrict__ src6, float * __restrict__ dst,
+        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+        const int src2_nb1, const int src2_nb2, const int src3_nb1,
+        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
+
+    const int head_idx = (blockIdx.x * splitH) / d_head;
+    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
+    const int seq_idx  = blockIdx.y;
+
+    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
+    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
+    float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+    // strides across n_seq_tokens
+    const int stride_x  = src1_nb2 / sizeof(float);
+    const int stride_dt = src2_nb1 / sizeof(float);
+    const int stride_B  = src4_nb2 / sizeof(float);
+    const int stride_C  = src5_nb2 / sizeof(float);
+    const int stride_y  = n_head * d_head;
+
+    float state[splitH];
+    // for the parallel accumulation
+    __shared__ float stateC[splitH * d_state];
+
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        state[j] = s0_block[j * d_state + threadIdx.x];
+    }
+
+    for (int64_t i = 0; i < n_tok; i++) {
+        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+        // TODO: only calculate B and C once per head group
+        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+        float dt_soft_plus = dt_block[i * stride_dt];
+        if (dt_soft_plus <= 20.0f) {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
+        }
+        const float dA = expf(dt_soft_plus * A_block[0]);
+        const float B  = B_block[i * stride_B + threadIdx.x];
+        const float C  = C_block[i * stride_C + threadIdx.x];
+
+        // across d_head
+#pragma unroll
+        for (int j = 0; j < splitH; j++) {
+            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+            state[j] = (state[j] * dA) + (B * x_dt);
+
+            stateC[j * d_state + threadIdx.x] = state[j] * C;
+        }
+
+        __syncthreads();
+
+        // parallel accumulation for stateC
+        // TODO: simplify
+        {
+            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+            // reduce until w matches the warp size
+            // TODO: does this work even when the physical warp size is 64?
+#pragma unroll
+            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+                // (assuming there are d_state threads)
+#pragma unroll
+                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+                    // TODO: check for bank conflicts
+                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+                    stateC[k] += stateC[k + (w >> 1)];
+
+                }
+                __syncthreads();
+            }
+
+            static_assert(splitH >= d_state / WARP_SIZE);
+
+#pragma unroll
+            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+                y = warp_reduce_sum(y);
+
+                // store the above accumulations
+                if (threadIdx.x % WARP_SIZE == 0) {
+                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+                    y_block[i * stride_y + k] = y;
+                }
+            }
+        }
+    }
+
+    // write back the state
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        s_block[j * d_state + threadIdx.x] = state[j];
+    }
+}
+
 static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
-                              const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
-                              const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
-                              const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                              const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                              float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
+                              const float * src4, const float * src5, const int32_t * src6, float * dst,
+                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
     const int threads = 128;
-    // todo: consider D cannot be divided,does this situation exist?
-    GGML_ASSERT(D % threads == 0);
-    const dim3 blocks(B, (D + threads - 1) / threads, 1);
-    const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
-    if (N == 16) {
-        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-            src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
-            src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+    if (src3_nb1 == sizeof(float)) {
+        // Mamba2
+        if (d_state == 128) {
+            GGML_ASSERT(d_state % threads == 0);
+            // NOTE: can be any power of two between 4 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=128.");
+        }
     } else {
-        GGML_ABORT("doesn't support N!=16.");
+        // Mamba1
+        // todo: consider n_head cannot be divided, does this situation exist?
+        GGML_ASSERT(n_head % threads == 0);
+        GGML_ASSERT(head_dim == 1);
+        GGML_ASSERT(n_group == 1);
+        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
+        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
+        if (d_state == 16) {
+            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=16.");
+        }
     }
 }
 
@@ -112,44 +244,42 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
-
-    // const int64_t d_state = src0->ne[0];
-    // const int64_t d_inner = src0->ne[1];
-    // const int64_t l = src1->ne[1];
-    // const int64_t b = src0->ne[2];
+    const struct ggml_tensor * src6 = dst->src[6]; // ids
 
     const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nr  = src0->ne[1]; // head_dim or 1
+    const int64_t nh  = src1->ne[1]; // n_head
+    const int64_t ng  = src4->ne[1]; // n_group
+    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
+    const int64_t n_s = src1->ne[3]; // number of sequences in the batch
+
+    const int64_t s_off = ggml_nelements(src1) * sizeof(float);
 
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
     const float * src2_d = (const float *) src2->data;
     const float * src3_d = (const float *) src3->data;
     const float * src4_d = (const float *) src4->data;
     const float * src5_d = (const float *) src5->data;
+    const int32_t * src6_d = (const int32_t *) src6->data;
     float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src6->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
-                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
-                      src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
 }
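
Both kernels implement the same selective state-space recurrence per head element: `dt` is passed through a softplus (skipped for large values), the state decays by `exp(dt * A)` and is bumped by `B * (x * dt)`, and each output element is the dot product of the updated state with `C` over `d_state`; the final states are then written into `dst` at byte offset `s_off = ggml_nelements(src1) * sizeof(float)`, right after the `y` values. Below is a rough scalar sketch of what one (sequence, head) slice computes, with assumed contiguous layouts and made-up names; it is illustrative only and not code from this commit.

```cpp
// Scalar reference of the per-head scan, under assumed layouts (not from ggml).
// A is a single scalar per head, matching the Mamba2 path above.
#include <cmath>

static void ssm_scan_reference(int n_tok, int d_head, int d_state,
                               const float * x,     // [n_tok][d_head]
                               const float * dt,    // [n_tok]
                               float         A,     // scalar decay per head
                               const float * B,     // [n_tok][d_state]
                               const float * C,     // [n_tok][d_state]
                               float *       state, // [d_head][d_state], carried across tokens
                               float *       y) {   // [n_tok][d_head]
    for (int t = 0; t < n_tok; ++t) {
        float dt_sp = dt[t];
        if (dt_sp <= 20.0f) {
            dt_sp = log1pf(expf(dt_sp));  // softplus, skipped for large inputs as in the kernel
        }
        const float dA = expf(dt_sp * A);
        for (int j = 0; j < d_head; ++j) {
            const float x_dt = x[t * d_head + j] * dt_sp;
            float acc = 0.0f;
            for (int s = 0; s < d_state; ++s) {
                float & st = state[j * d_state + s];
                st   = st * dA + B[t * d_state + s] * x_dt;
                acc += st * C[t * d_state + s];  // dot product of the new state with C
            }
            y[t * d_head + j] = acc;
        }
    }
}
```

In `ssm_scan_f32_group`, the inner dot product over `d_state` is what the shared `stateC` buffer and `warp_reduce_sum` parallelize across the block's `d_state` threads, while each thread keeps `splitH` head elements of the state in registers.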