+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+ #define USE_CUB
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+ #ifdef USE_CUB
+ #include <cub/cub.cuh>
+ using namespace cub;
+ #endif // USE_CUB
+
#include "ssm-scan.cuh"

- template <size_t splitD, size_t N>
- __global__ void __launch_bounds__(splitD, 2)
-     ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
-                  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+ // We would like to keep pragma unroll for cases where L_template is not 0,
+ // so we suppress the clang transformation warning.
+ #ifdef __clang__
+ #pragma clang diagnostic push
+ #pragma clang diagnostic ignored "-Wpass-failed"
+ #endif // __clang__
+ template <size_t splitD, size_t N, size_t L_template>
+ __global__ void __launch_bounds__(splitD, 1)
+     ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+                  const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
-                  const int64_t s_off, const int64_t d_inner, const int64_t L) {
-
-     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-     const int bidx = blockIdx.x;  // split along B (sequences)
-     const int bidy = blockIdx.y;  // split along D (d_inner)
-     const int tid  = threadIdx.x;
-     const int wid  = tid / 32;
-     const int wtid = tid % 32;
-
-     extern __shared__ float smem[];
-     const int stride_sA  = N + 1;
-     const int stride_ss0 = N + 1;
-     float * smem_A  = smem;
-     float * smem_s0 = smem_A + splitD * stride_sA;
-
-     const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
-     const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
-     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
-     const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-     const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
-     const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
-     float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
-     float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
-
-     const int stride_s0 = src0_nb2 / sizeof(float);
-     const int stride_x  = src1_nb2 / sizeof(float);
+                  const int64_t s_off, const int64_t d_inner, const int64_t L_param)
+ {
+     const size_t L = L_template == 0 ? L_param : L_template;
+     const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+     const float *x_block  = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
+     const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+     const float *A_block  = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+     const float *B_block  = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
+     const float *C_block  = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
+     float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
+     float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+
+     const int stride_x = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
-     const int stride_A  = src3_nb1 / sizeof(float);
-     const int stride_B  = src4_nb2 / sizeof(float);
-     const int stride_C  = src5_nb2 / sizeof(float);
-     const int stride_s  = stride_s0;
-     const int stride_y  = d_inner;
+     const int stride_B = src4_nb2 / sizeof(float);
+     const int stride_C = src5_nb2 / sizeof(float);
+     const int stride_y = d_inner;

-     // can N not be 16? for example 32?
-     if (N == 16) {
- #pragma unroll
-         for (size_t i = 0; i < splitD / 4; i += 2) {
-             float value = A_block[(wid * warp_size + i) * stride_A + wtid];
-             // todo: bank conflict
-             // I am always confused with how to use the swizzling method to solve
-             // bank conflit. Hoping somebody can tell me.
-             smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-         }
+     float regA[N];
+     float regs0[N];
+
+     __shared__ float smemB[N];
+     __shared__ float smemC[N];
+
+ #ifdef USE_CUB
+     using BlockLoad  = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+     using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+
+     union CubTempStorage {
+         typename BlockLoad::TempStorage load_temp;
+         typename BlockStore::TempStorage store_temp;
+     };
+     __shared__ CubTempStorage cub_temp_storage;
+
+     BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
+     BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
+ #else
+     const int stride_s0 = src0_nb2 / sizeof(float);
+     const int stride_A  = src3_nb1 / sizeof(float);
#pragma unroll
-         for (size_t i = 0; i < splitD / 4; i += 2) {
-             float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
-             smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-         }
+     for (size_t n = 0; n < N; ++n)
+     {
+         regA[n]  = A_block[threadIdx.x * stride_A + n];
+         regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
    }
+ #endif

-     __syncthreads();
+ #pragma unroll
+     for (size_t i = 0; i < L; i++)
+     {
+         if (threadIdx.x < N)
+         {
+             smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+             smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+         }
+         __syncthreads();

-     for (int64_t i = 0; i < L; i++) {
-         float dt_soft_plus = dt_block[i * stride_dt + tid];
-         if (dt_soft_plus <= 20.0f) {
-             dt_soft_plus = log1pf(exp(dt_soft_plus));
+         float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+         if (dt_soft_plus <= 20.0f)
+         {
+             dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
-         float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+         float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
+
        float sumf = 0.0f;
#pragma unroll
-         for (size_t j = 0; j < N; j++) {
-             float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
-                           (B_block[i * stride_B + j] * x_dt);
-             sumf += state * C_block[i * stride_C + j];
-             if (i == L - 1) {
-                 s_block[tid * stride_s + j] = state;
-             } else {
-                 smem_s0[tid * stride_ss0 + j] = state;
-             }
+         for (size_t n = 0; n < N; n++)
+         {
+             float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
+             sumf += state * smemC[n];
+             regs0[n] = state;
        }
-         __syncthreads();
-         y_block[i * stride_y + tid] = sumf;
+         y_block[i * stride_y + threadIdx.x] = sumf;
    }
+
+ #ifdef USE_CUB
+     BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
+ #else
+     const int stride_s = stride_s0;
+ #pragma unroll
+     for (size_t n = 0; n < N; ++n)
+     {
+         s_block[threadIdx.x * stride_s + n] = regs0[n];
+     }
+ #endif
}
+ #ifdef __clang__
+ #pragma clang diagnostic pop
+ #endif // __clang__

// assumes as many threads as d_state
template <int splitH, int d_state>
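Note: inside `#ifdef USE_CUB`, the kernel above swaps the strided per-thread loads of A_block and s0_block for cub::BlockLoad and writes the final state back with cub::BlockStore, so global accesses stay coalesced while each thread keeps its N values in registers. For reference, a minimal standalone sketch of that load/store pattern (the kernel name, tile sizes, and buffers here are illustrative, not taken from this diff):

// Illustrative only: each block loads a contiguous tile into per-thread
// registers with BLOCK_LOAD_WARP_TRANSPOSE and stores it back the same way.
#include <cub/cub.cuh>

template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void block_copy(const float *in, float *out) {
    using Load  = cub::BlockLoad<float, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using Store = cub::BlockStore<float, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    // Load and Store can share one union of temp storage, as in the kernel above.
    __shared__ union {
        typename Load::TempStorage  load;
        typename Store::TempStorage store;
    } tmp;

    float items[ITEMS_PER_THREAD];  // per-thread registers
    const int tile = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD;
    Load(tmp.load).Load(in + tile, items);
    __syncthreads();                // temp storage is reused by Store
    Store(tmp.store).Store(out + tile, items);
}

// usage, e.g.: block_copy<128, 16><<<n_tiles, 128, 0, stream>>>(d_in, d_out);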
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
+     const int threads = 128;
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state == 128) {
-             const int threads = 128;
            GGML_ASSERT(d_state % threads == 0);
            // NOTE: can be any power of two between 4 and 64
            const int splitH = 16;
@@ -229,18 +259,70 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
            GGML_ABORT("doesn't support d_state!=(128 or 256).");
        }
    } else {
-         const int threads = 128;
        // Mamba-1
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);
        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
        if (d_state == 16) {
-             ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-                 src0, src1, src2, src3, src4, src5, src6, dst,
+             switch (n_tok)
+             {
+                 case 1:
+                     ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 case 2:
+                     ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 case 3:
+                     ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 case 4:
+                     ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 case 5:
+                     ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 case 6:
+                     ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 case 7:
+                     ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 case 8:
+                     ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+                 default:
+                     ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
+                         src0, src1, src2, src3, src4, src5, src6, dst,
+                         src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                         src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                     break;
+             }
        } else {
            GGML_ABORT("doesn't support d_state!=16.");
        }
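The switch above exists so that short sequences (n_tok from 1 to 8) instantiate the kernel with a compile-time L_template, letting `#pragma unroll` fully unroll the per-token loop, while `L_template == 0` selects the generic runtime-length fallback. A minimal sketch of that dispatch pattern (the kernel, launcher, and sizes here are illustrative, not from this diff):

// Illustrative only: map a runtime length onto compile-time specializations.
template <int L_template>
__global__ void scan_kernel(const float *x, float *y, int L_param) {
    // L folds to a constant whenever L_template != 0.
    const int L = (L_template == 0) ? L_param : L_template;
    float acc = 0.0f;
#pragma unroll
    for (int i = 0; i < L; ++i) {  // fully unrolled when L_template != 0
        acc += x[i * blockDim.x + threadIdx.x];
    }
    y[threadIdx.x] = acc;
}

// Assumes x holds at least n_tok * 128 floats and y holds 128 floats.
static void launch_scan(const float *x, float *y, int n_tok, cudaStream_t stream) {
    switch (n_tok) {
        case 1:  scan_kernel<1><<<1, 128, 0, stream>>>(x, y, n_tok); break;
        case 2:  scan_kernel<2><<<1, 128, 0, stream>>>(x, y, n_tok); break;
        default: scan_kernel<0><<<1, 128, 0, stream>>>(x, y, n_tok); break; // runtime length
    }
}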