// CUB (device-side block load/store primitives) is only available with the
// NVIDIA toolchain and requires CUDA >= 11.7; HIP and MUSA builds fall back
// to the plain strided-load path guarded by USE_CUB below.
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

#ifdef USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // USE_CUB

#include "ssm-scan.cuh"
211
// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning ("-Wpass-failed") that is
// emitted for the non-unrollable runtime-L instantiation.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
// Mamba-1 selective state-space scan (f32).
//
// Launch contract (see the host launcher): one block per (sequence, splitD-chunk
// of d_inner): blockIdx.x indexes the sequence, blockIdx.y the chunk; blockDim.x
// must equal splitD, with one thread per inner channel. N is the state dimension
// (16 here) and must satisfy N <= splitD so the smemB/smemC staging below covers
// all state columns.
//
// L_template: compile-time token count. When non-zero, L is a constant and the
// token loop fully unrolls; L_template == 0 selects the generic runtime-L path
// (L_param). The *_nb* arguments are byte strides taken from the ggml tensors.
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
    // Prefer the compile-time token count so the main loop can unroll.
    const size_t L = L_template == 0 ? L_param : L_template;

    // Per-block base pointers. src6 maps blockIdx.x (sequence) to the slot of
    // its initial state inside src0; all offsets are in bytes, hence the char* casts.
    const float * s0_block = (const float *) ((const char *) src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
    const float * x_block  = (const float *) ((const char *) src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + blockIdx.y * splitD * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + (blockIdx.x * src4_nb3));
    const float * C_block  = (const float *) ((const char *) src5 + (blockIdx.x * src5_nb3));
    // dst holds the outputs y first; the final states live at byte offset s_off.
    float * y_block = (float *) ((char *) dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
    float * s_block = (float *) ((char *) dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);

    // Element (not byte) strides for the per-token indexing below.
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = d_inner;

    // Each thread keeps its channel's full state row (N floats) and the
    // matching row of A in registers for the whole scan.
    float regA[N];
    float regs0[N];

    // B and C are shared across all channels of a token; stage them once per
    // token in shared memory instead of re-reading global memory per thread.
    __shared__ float smemB[N];
    __shared__ float smemC[N];

#ifdef USE_CUB
    // Coalesced register blocking: WARP_TRANSPOSE loads N consecutive floats
    // per thread from the row-major A / s0 blocks.
    using BlockLoad  = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    // Load and store phases never overlap, so their temp storage may alias.
    union CubTempStorage {
        typename BlockLoad::TempStorage  load_temp;
        typename BlockStore::TempStorage store_temp;
    };
    __shared__ CubTempStorage cub_temp_storage;

    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
    // Fallback: plain strided per-thread loads (uncoalesced but correct).
    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_A  = src3_nb1 / sizeof(float);
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        regA[n]  = A_block[threadIdx.x * stride_A + n];
        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
    }
#endif

#pragma unroll
    for (size_t i = 0; i < L; i++)
    {
        // First N threads stage this token's B and C vectors.
        if (threadIdx.x < N)
        {
            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
        }
        __syncthreads();

        // softplus(dt), with the standard large-input shortcut: for
        // dt > 20 softplus(dt) == dt to float precision.
        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
        if (dt_soft_plus <= 20.0f)
        {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

        // State update and output: s = s*exp(dt*A) + B*x_dt; y = <s, C>.
        float sumf = 0.0f;
#pragma unroll
        for (size_t n = 0; n < N; n++)
        {
            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
            sumf += state * smemC[n];
            regs0[n] = state;
        }
        y_block[i * stride_y + threadIdx.x] = sumf;

        // Barrier before the next iteration overwrites smemB/smemC: without it
        // a fast warp could stage token i+1 while others still read token i.
        // L is uniform across the block, so every thread reaches this barrier.
        __syncthreads();
    }

    // Persist the final state for this (sequence, chunk).
#ifdef USE_CUB
    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
    const int stride_s = stride_s0;
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        s_block[threadIdx.x * stride_s + n] = regs0[n];
    }
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
85115
86116// assumes as many threads as d_state
87117template <int splitH, int d_state>
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
201231 const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202232 const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203233 cudaStream_t stream) {
234+ const int threads = 128 ;
204235 // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
205236 if (src3_nb1 == sizeof (float )) {
206237 // Mamba-2
207238 if (d_state == 128 ) {
208- const int threads = 128 ;
209239 GGML_ASSERT (d_state % threads == 0 );
210240 // NOTE: can be any power of two between 4 and 64
211241 const int splitH = 16 ;
@@ -229,18 +259,70 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
229259 GGML_ABORT (" doesn't support d_state!=(128 or 256)." );
230260 }
231261 } else {
232- const int threads = 128 ;
233262 // Mamba-1
234263 GGML_ASSERT (n_head % threads == 0 );
235264 GGML_ASSERT (head_dim == 1 );
236265 GGML_ASSERT (n_group == 1 );
237266 const dim3 blocks (n_seq, (n_head + threads - 1 ) / threads, 1 );
238267 const int smem_size = (threads * (d_state + 1 ) * 2 ) * sizeof (float );
239268 if (d_state == 16 ) {
240- ssm_scan_f32<128 , 16 ><<<blocks, threads, smem_size, stream>>> (
241- src0, src1, src2, src3, src4, src5, src6, dst,
269+ switch (n_tok)
270+ {
271+ case 1 :
272+ ssm_scan_f32<threads, 16 , 1 ><<<blocks, threads, smem_size, stream>>> (
273+ src0, src1, src2, src3, src4, src5, src6, dst,
274+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
275+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
276+ break ;
277+ case 2 :
278+ ssm_scan_f32<threads, 16 , 2 ><<<blocks, threads, smem_size, stream>>> (
279+ src0, src1, src2, src3, src4, src5, src6, dst,
280+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
281+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
282+ break ;
283+ case 3 :
284+ ssm_scan_f32<threads, 16 , 3 ><<<blocks, threads, smem_size, stream>>> (
285+ src0, src1, src2, src3, src4, src5, src6, dst,
286+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
287+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
288+ break ;
289+ case 4 :
290+ ssm_scan_f32<threads, 16 , 4 ><<<blocks, threads, smem_size, stream>>> (
291+ src0, src1, src2, src3, src4, src5, src6, dst,
292+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
293+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
294+ break ;
295+ case 5 :
296+ ssm_scan_f32<threads, 16 , 5 ><<<blocks, threads, smem_size, stream>>> (
297+ src0, src1, src2, src3, src4, src5, src6, dst,
242298 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
243299 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
300+ break ;
301+ case 6 :
302+ ssm_scan_f32<threads, 16 , 6 ><<<blocks, threads, smem_size, stream>>> (
303+ src0, src1, src2, src3, src4, src5, src6, dst,
304+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
305+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
306+ break ;
307+ case 7 :
308+ ssm_scan_f32<threads, 16 , 7 ><<<blocks, threads, smem_size, stream>>> (
309+ src0, src1, src2, src3, src4, src5, src6, dst,
310+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
311+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
312+ break ;
313+ case 8 :
314+ ssm_scan_f32<threads, 16 , 8 ><<<blocks, threads, smem_size, stream>>> (
315+ src0, src1, src2, src3, src4, src5, src6, dst,
316+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
317+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
318+ break ;
319+ default :
320+ ssm_scan_f32<threads, 16 , 0 ><<<blocks, threads, smem_size, stream>>> (
321+ src0, src1, src2, src3, src4, src5, src6, dst,
322+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
323+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
324+ break ;
325+ }
244326 } else {
245327 GGML_ABORT (" doesn't support d_state!=16." );
246328 }