Commit 79c1160

cuda: refactored ssm_scan and use CUB (#13291)
* cuda: refactored ssm_scan to use CUB
* fixed compilation error when not using CUB
* assign L to constant and use size_t instead of int
* deduplicated functions
* change min blocks per mp to 1
* Use cub load and store warp transpose
* suppress clang warning
1 parent 34c9d76 commit 79c1160
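
The core of the change is the CUB block-level load/store path named in the commit message. As a point of reference, here is a minimal, self-contained sketch (not taken from the commit; the kernel name copy_tile and its launch shape are illustrative) of how cub::BlockLoad/cub::BlockStore with the WARP_TRANSPOSE policies move a blockDim.x × N tile between global memory and per-thread registers, which is the pattern the refactored ssm_scan_f32 uses for A and the initial state s0:

#include <cub/cub.cuh>

// Illustrative only: each of the BLOCK_THREADS threads ends up holding
// ITEMS_PER_THREAD consecutive floats (one "row" of the tile) in registers,
// while the global-memory accesses themselves stay coalesced (warp-striped).
template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void copy_tile(const float * in, float * out) {
    using BlockLoad  = cub::BlockLoad <float, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    // The load and store phases do not overlap, so their scratch space can
    // alias, mirroring the CubTempStorage union in the diff below.
    __shared__ union {
        typename BlockLoad::TempStorage  load;
        typename BlockStore::TempStorage store;
    } temp;

    float items[ITEMS_PER_THREAD];
    BlockLoad(temp.load).Load(in, items);      // tile -> registers
    __syncthreads();                           // temp is reused by the store
    BlockStore(temp.store).Store(out, items);  // registers -> tile
}

// e.g. copy_tile<128, 16><<<1, 128>>>(d_in, d_out); for a 128 x 16 float tile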

1 file changed

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 152 additions & 70 deletions
@@ -1,87 +1,117 @@
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
 #include "ssm-scan.cuh"

-template <size_t splitD, size_t N>
-__global__ void __launch_bounds__(splitD, 2)
-    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
-                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+// We would like to keep pragma unroll for cases where L_template is not 0,
+// so we suppress the clang transformation warning.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <size_t splitD, size_t N, size_t L_template>
+__global__ void __launch_bounds__(splitD, 1)
+    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                  const int32_t * __restrict__ src6, float * __restrict__ dst,
                  const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                  const int src2_nb1, const int src2_nb2, const int src3_nb1,
                  const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
-                 const int64_t s_off, const int64_t d_inner, const int64_t L) {
-
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x; // split along B (sequences)
-    const int bidy = blockIdx.y; // split along D (d_inner)
-    const int tid = threadIdx.x;
-    const int wid = tid / 32;
-    const int wtid = tid % 32;
-
-    extern __shared__ float smem[];
-    const int stride_sA = N + 1;
-    const int stride_ss0 = N + 1;
-    float * smem_A = smem;
-    float * smem_s0 = smem_A + splitD * stride_sA;
-
-    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
-    const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
-    const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
-    const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
-    float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
-    float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
-
-    const int stride_s0 = src0_nb2 / sizeof(float);
-    const int stride_x  = src1_nb2 / sizeof(float);
+                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
+{
+    const size_t L = L_template == 0 ? L_param : L_template;
+    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
+    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
+    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
+    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
+    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+
+    const int stride_x = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
-    const int stride_A = src3_nb1 / sizeof(float);
-    const int stride_B = src4_nb2 / sizeof(float);
-    const int stride_C = src5_nb2 / sizeof(float);
-    const int stride_s = stride_s0;
-    const int stride_y = d_inner;
+    const int stride_B = src4_nb2 / sizeof(float);
+    const int stride_C = src5_nb2 / sizeof(float);
+    const int stride_y = d_inner;

-    // can N not be 16? for example 32?
-    if (N == 16) {
-#pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = A_block[(wid * warp_size + i) * stride_A + wtid];
-            // todo: bank conflict
-            // I am always confused with how to use the swizzling method to solve
-            // bank conflit. Hoping somebody can tell me.
-            smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
+    float regA[N];
+    float regs0[N];
+
+    __shared__ float smemB[N];
+    __shared__ float smemC[N];
+
+#ifdef USE_CUB
+    using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+
+    union CubTempStorage {
+        typename BlockLoad::TempStorage load_temp;
+        typename BlockStore::TempStorage store_temp;
+    };
+    __shared__ CubTempStorage cub_temp_storage;
+
+    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
+    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
+#else
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_A = src3_nb1 / sizeof(float);
 #pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
-            smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
+    for (size_t n = 0; n < N; ++n)
+    {
+        regA[n] = A_block[threadIdx.x * stride_A + n];
+        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
     }
+#endif

-    __syncthreads();
+#pragma unroll
+    for (size_t i = 0; i < L; i++)
+    {
+        if (threadIdx.x < N)
+        {
+            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+        }
+        __syncthreads();

-    for (int64_t i = 0; i < L; i++) {
-        float dt_soft_plus = dt_block[i * stride_dt + tid];
-        if (dt_soft_plus <= 20.0f) {
-            dt_soft_plus = log1pf(exp(dt_soft_plus));
+        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+        if (dt_soft_plus <= 20.0f)
+        {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
         }
-        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
+
         float sumf = 0.0f;
 #pragma unroll
-        for (size_t j = 0; j < N; j++) {
-            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
-                          (B_block[i * stride_B + j] * x_dt);
-            sumf += state * C_block[i * stride_C + j];
-            if (i == L - 1) {
-                s_block[tid * stride_s + j] = state;
-            } else {
-                smem_s0[tid * stride_ss0 + j] = state;
-            }
+        for (size_t n = 0; n < N; n++)
+        {
+            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
+            sumf += state * smemC[n];
+            regs0[n] = state;
         }
-        __syncthreads();
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + threadIdx.x] = sumf;
     }
+
+#ifdef USE_CUB
+    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
+#else
+    const int stride_s = stride_s0;
+#pragma unroll
+    for (size_t n = 0; n < N; ++n)
+    {
+        s_block[threadIdx.x * stride_s + n] = regs0[n];
+    }
+#endif
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__

 // assumes as many threads as d_state
 template <int splitH, int d_state>
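
A note on the hunk above: the new L_template parameter is what the clang pragmas at the top are for. When the host dispatches a non-zero specialization, the trip count of the per-token loop is a compile-time constant and #pragma unroll can fully unroll it; L_template == 0 keeps the runtime L_param, which is the case that would otherwise trigger clang's -Wpass-failed warning. A reduced sketch of the idiom (illustrative kernel, not from the file):

#include <cstdint>

// Illustrative only: the loop bound is either a template constant (fully
// unrollable) or, when the template argument is 0, a value read at runtime.
template <size_t L_template>
__global__ void sum_L_rows(const float * x, float * y, const int64_t L_param) {
    const size_t L = L_template == 0 ? L_param : L_template;
    float acc = 0.0f;
#pragma unroll
    for (size_t i = 0; i < L; ++i) {   // constant trip count when L_template != 0
        acc += x[i * blockDim.x + threadIdx.x];
    }
    y[threadIdx.x] = acc;
}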
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                            const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                            const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                            cudaStream_t stream) {
+    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
-            const int threads = 128;
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
             const int splitH = 16;
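
The hunk above only hoists const int threads = 128; out of the Mamba-2 branch so that both branches of ssm_scan_f32_cuda share it. Because it is a const int with a literal initializer it remains a constant expression, so the Mamba-1 launches below can pass it as the splitD template argument as well as the block size. A minimal illustration of that rule (hypothetical names, not from the file):

// Illustrative only: a const int initialized with a literal may be used both
// as a non-type template argument and as the runtime launch configuration.
template <int splitD>
__global__ void dummy_kernel() {}

static void launch_dummy() {
    const int threads = 128;                 // constant expression
    dummy_kernel<threads><<<1, threads>>>(); // same value feeds <...> and <<<...>>>
}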
@@ -229,18 +259,70 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
             GGML_ABORT("doesn't support d_state!=(128 or 256).");
         }
     } else {
-        const int threads = 128;
         // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
         GGML_ASSERT(n_group == 1);
         const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
         const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
         if (d_state == 16) {
-            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-                src0, src1, src2, src3, src4, src5, src6, dst,
+            switch (n_tok)
+            {
+            case 1:
+                ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 2:
+                ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 3:
+                ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 4:
+                ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 5:
+                ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                     src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 6:
+                ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 7:
+                ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 8:
+                ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            default:
+                ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                    src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            }
         } else {
             GGML_ABORT("doesn't support d_state!=16.");
         }
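
The switch above maps the runtime token count onto the unrolled kernel specializations: n_tok values 1 through 8 select a fully unrolled per-token loop (the single-token decode path being the most important), and anything larger, such as prompt processing, falls back to the generic L_template == 0 instantiation that reads the length at runtime, at the cost of compiling nine instantiations. A small helper expressing the same mapping (hypothetical, not in the commit):

#include <cstdint>
#include <cstddef>

// Hypothetical helper, not in the commit: which L_template the switch above
// effectively selects for a given token count.
static size_t ssm_scan_specialization(int64_t n_tok) {
    return (n_tok >= 1 && n_tok <= 8) ? (size_t) n_tok : 0; // 0 = runtime-length kernel
}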
