ggml: Mamba CUDA kernel performance improvement #9186
`ssm_conv.cu` (new file, `@@ -0,0 +1,100 @@`):

```cuda
#include "ssm_conv.cuh"

template <int block_size>
static __global__ void ssm_conv_f32(
        const float * src0, const float * src1,
        const int src0_nb0, const int src0_nb1, const int src0_nb2,
        const int src1_nb1,
        float * dst,
        const int dst_nb0, const int dst_nb1, const int dst_nb2,
        const int nc, const int ncs, const int nr) {

    // const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
    const int i2  = blockIdx.x;
    const int i3  = threadIdx.y;

    const int ith = tid;
    const int nth = WARP_SIZE;

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    // {d_conv - 1 + n_t, d_inner, n_seqs}
    // sliding window
    const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
    const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1);                             // {d_conv, d_inner}
    float       * x = (float *)       ((char *)       dst  + ir0*dst_nb0  + i2*dst_nb1  + i3*dst_nb2);  // {d_inner, n_t, n_s}

    // TODO: transpose the output for smaller strides for big batches?
    // d_inner
    #pragma unroll
    for (int i1 = 0; i1 < ir; ++i1) {
        // rowwise dot product
        // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
        float sumf = 0.0f;
        #pragma unroll
        for (int i0 = 0; i0 < nc; ++i0) {
            sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
```
Review comments on the inner dot-product loop:

Reviewer: Unless I'm missing something, the memory access pattern here is bad, with each thread accessing completely different data. You would achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.

Another reviewer: Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with … In practice, for Mamba (and Mamba-2) models, … To coalesce memory accesses (at least for large prompts), I guess each warp could operate on … I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.
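To make the coalescing idea concrete, here is a minimal sketch of one possible mapping: one block per channel and sequence, with consecutive warp lanes handling consecutive tokens, so each tap of the self-overlapping window is read from consecutive addresses. The kernel name, the compile-time `d_conv` template parameter, and the assumption of contiguous element strides (instead of the `nb*` byte strides) are all illustrative; this is not the implementation the PR ended up with.

```cuda
// Illustrative only: a coalesced mapping for ssm_conv, assuming contiguous
// tensors and d_conv known at compile time (d_conv = 4 for Mamba models).
template <int d_conv>
static __global__ void ssm_conv_f32_coalesced(
        const float * src0,   // {d_conv - 1 + n_t, d_inner, n_s}  self-overlapping window
        const float * src1,   // {d_conv, d_inner}                 conv weights
        float * dst,          // {d_inner, n_t, n_s}
        const int ncs, const int nr, const int n_t) {

    const int i1 = blockIdx.x;                             // channel (d_inner)
    const int i3 = blockIdx.y;                             // sequence
    const int i2 = blockIdx.z*blockDim.x + threadIdx.x;    // token

    if (i2 >= n_t) {
        return;
    }

    // the window for token i2 starts at element i2 of this channel's row
    const float * s = src0 + (i3*nr + i1)*ncs + i2;
    const float * c = src1 + i1*d_conv;

    float sumf = 0.0f;
#pragma unroll
    for (int i0 = 0; i0 < d_conv; ++i0) {
        // consecutive lanes (consecutive i2) read consecutive floats here
        sumf += s[i0] * c[i0];
    }

    // note: this store is still strided by d_inner across lanes; the existing
    // TODO about transposing the output would address that
    dst[(i3*n_t + i2)*nr + i1] = sumf;
}
```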
Author: Thanks. The current memory access pattern is more suitable for CPUs. I'm thinking about ways to address this issue.
Author: Good idea, I am currently testing this approach.
Author: I've found a simple implementation for ssm_conv that coalesces memory accesses and gives about a 2x performance improvement, and I've already submitted it to the PR! For ssm_scan, I'm still at a loss for optimization ideas.
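For completeness, the hypothetical `ssm_conv_f32_coalesced` sketch above could be launched along these lines. This is illustrative only; the values are the same ones the existing launcher already derives from the tensors, and the `d_conv = 4` template argument is the typical Mamba value, not something the PR hard-codes.

```cuda
// inside a launcher like ssm_conv_f32_cuda(...), given src0_d, src1_d, dst_d,
// the sizes ncs/nr/n_t/n_s and the stream:
// grid.x over channels, grid.y over sequences, grid.z tiles the tokens so
// each warp covers WARP_SIZE consecutive tokens
const dim3 grid(nr, n_s, (n_t + WARP_SIZE - 1)/WARP_SIZE);
ssm_conv_f32_coalesced<4><<<grid, WARP_SIZE, 0, stream>>>(src0_d, src1_d, dst_d, ncs, nr, n_t);
```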
Reviewer: Use one warp per iteration of …

Author: Ha ha, I also once thought about using the warp-level API to calculate the sum. However, after taking a closer look, I realized that these additions accumulate into a single thread's registers, not a sum across the threads of a block, so warp_reduce_sum might not be applicable here. Thanks for your review. If you have any other suggestions or better ideas, please feel free to share them; your input is greatly appreciated.
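As a side note on the warp-level reduction question: if the inner dot product were spread across the lanes of a warp (one warp per output row), the partial sums would indeed need to be combined across lanes. The sketch below shows that generic pattern using shuffle intrinsics directly; it is an illustration of the technique under discussion, not code from this PR. For ssm_conv itself the inner length is only d_conv (typically 4 for Mamba), so the pattern is more interesting for the d_state loop in ssm_scan.

```cuda
// Illustrative warp-level dot product: each lane accumulates a strided slice
// of the inner dimension, then the per-lane partial sums are combined with a
// butterfly shuffle reduction (the same idea as warp_reduce_sum in common.cuh).
static __device__ __forceinline__ float warp_dot_f32(const float * a, const float * b, const int n) {
    float sum = 0.0f;
    // lane k handles elements k, k + WARP_SIZE, k + 2*WARP_SIZE, ... so
    // consecutive lanes read consecutive addresses (coalesced loads)
    for (int i = threadIdx.x % WARP_SIZE; i < n; i += WARP_SIZE) {
        sum += a[i] * b[i];
    }
    // reduce the per-lane partial sums across the warp
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }
    return sum; // every lane ends up holding the full dot product
}
```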
(ssm_conv.cu, continued)

```cuda
        }
        x[i1] = sumf;
    }
}

static void ssm_conv_f32_cuda(
        const float * src0, const float * src1,
        const int src0_nb0, const int src0_nb1, const int src0_nb2,
        const int src1_nb1,
        float * dst,
        const int dst_nb0, const int dst_nb1, const int dst_nb2,
        const int nc, const int ncs, const int nr, const int n_t, const int n_s,
        cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, n_s, 1);
    const int  nblocks = n_t;

    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
        src0, src1,
        src0_nb0, src0_nb1, src0_nb2,
        src1_nb1,
        dst,
        dst_nb0, dst_nb1, dst_nb2,
        nc, ncs, nr);
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

    const int nc  = src1->ne[0]; // d_conv
    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
    const int nr  = src0->ne[1]; // d_inner
    const int n_t = dst->ne[1];  // tokens per sequence
    const int n_s = dst->ne[2];  // number of sequences in the batch

    GGML_ASSERT( dst->ne[0] == nr);
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    ssm_conv_f32_cuda(src0_d, src1_d,
        src0->nb[0], src0->nb[1], src0->nb[2],
        src1->nb[1],
        dst_d,
        dst->nb[0], dst->nb[1], dst->nb[2],
        nc, ncs, nr, n_t, n_s,
        stream);
}
```
`ssm_conv.cuh` (new file, `@@ -0,0 +1,3 @@`):

```cuda
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```
Reviewer: My comments for …
`ssm_scan.cu` (new file, `@@ -0,0 +1,144 @@`):

```cuda
#include "ssm_scan.cuh"

template <int block_size>
static __global__ void ssm_scan_f32(
        const float * src0, const float * src1, const float * src2, const float * src3,
        const float * src4, const float * src5,
        const int src0_nb1, const int src0_nb2,
        const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
        const int src2_nb0, const int src2_nb1, const int src2_nb2,
        const int src3_nb1,
        const int src4_nb1, const int src4_nb2,
        const int src5_nb1, const int src5_nb2,
        float * dst,
        const int nc, const int nr) {

    const int tid = threadIdx.x;
    const int i2  = blockIdx.x;
    const int i3  = threadIdx.y;

    const int ith = tid;
    const int nth = WARP_SIZE;

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2);               // {d_state, d_inner, n_s}
    const float * x  = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
    const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
    const float * A  = (const float *) ((const char *) src3 + ir0*src3_nb1);                             // {d_state, d_inner}
    const float * B  = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2);                // {d_state, n_t, n_s}
    const float * C  = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2);                // {d_state, n_t, n_s}
    float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2);                     // {d_inner, n_t, n_s}
    float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3);                        // {d_state, d_inner, n_s}

    // use the output as the source for the next token-wise iterations
    if (i2 > 0) { s0 = s; }

    // d_inner
    #pragma unroll
    for (int i1 = 0; i1 < ir; ++i1) {
        // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
        float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
        float x_dt = x[i1] * dt_soft_plus;
        float sumf = 0.0f;
        // d_state
        #pragma unroll
        for (int i0 = 0; i0 < nc; ++i0) {
            int i = i0 + i1*nc;
            // state = prev_state * dA + dB * x
            float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
            // y = rowwise_dotprod(state, C)
            sumf += state * C[i0];
            s[i] = state;
        }
        y[i1] = sumf;
    }
}

static void ssm_scan_f32_cuda(
        const float * src0, const float * src1, const float * src2, const float * src3,
        const float * src4, const float * src5,
        const int src0_nb1, const int src0_nb2,
        const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
        const int src2_nb0, const int src2_nb1, const int src2_nb2,
        const int src3_nb1,
        const int src4_nb1, const int src4_nb2,
        const int src5_nb1, const int src5_nb2,
        float * dst,
        const int nc, const int nr, const int n_t, const int n_s,
        cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, n_s, 1);
    const int  nblocks = n_t;

    ssm_scan_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
        src0, src1, src2, src3,
        src4, src5,
        src0_nb1, src0_nb2,
        src1_nb0, src1_nb1, src1_nb2, src1_nb3,
        src2_nb0, src2_nb1, src2_nb2,
        src3_nb1,
        src4_nb1, src4_nb2,
        src5_nb1, src5_nb2,
        dst,
        nc, nr);
}

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // s
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // dt
    const struct ggml_tensor * src3 = dst->src[3]; // A
    const struct ggml_tensor * src4 = dst->src[4]; // B
    const struct ggml_tensor * src5 = dst->src[5]; // C

    const int64_t nc  = src0->ne[0]; // d_state
    const int64_t nr  = src0->ne[1]; // d_inner
    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
    const int64_t n_s = src0->ne[2]; // number of sequences in the batch

    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    // required for the dot product between s and C
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
    // required for per-sequence offsets for states
    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
    // required to get correct offset for state destination (i.e. src1->nb[3])
    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    const float * src2_d = (const float *) src2->data;
    const float * src3_d = (const float *) src3->data;
    const float * src4_d = (const float *) src4->data;
    const float * src5_d = (const float *) src5->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    ssm_scan_f32_cuda(
        src0_d, src1_d, src2_d, src3_d,
        src4_d, src5_d,
        src0->nb[1], src0->nb[2],
        src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
        src2->nb[0], src2->nb[1], src2->nb[2],
        src3->nb[1],
        src4->nb[1], src4->nb[2],
        src5->nb[1], src5->nb[2],
        dst_d,
        nc, nr, n_t, n_s,
        stream);
}
```
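Since the review comments on ssm_conv also apply here, and the author mentioned being unsure how to optimize ssm_scan, here is one hedged sketch of how the d_state loop above could be spread across warp lanes: one warp per d_inner row, each lane owning one state element, with the state/C dot product combined by a shuffle reduction while the recurrence over tokens stays sequential. It assumes d_state <= WARP_SIZE (it is 16 for Mamba-1 models) and contiguous element strides, and it is purely an illustration, not part of the PR.

```cuda
// Illustrative only: per-row ssm_scan body with the d_state loop distributed
// across warp lanes. Pointers are assumed to already be offset to one
// (d_inner row, sequence): x/dt/y advance by d_inner per token, B/C by
// d_state per token, and s holds this row's d_state state elements.
static __device__ void ssm_scan_row_warp(
        const float * s0, const float * x, const float * dt,
        const float * A,  const float * B, const float * C,
        float * s, float * y,
        const int d_state, const int d_inner, const int n_t) {

    const int lane = threadIdx.x % WARP_SIZE;

    for (int i2 = 0; i2 < n_t; ++i2) {
        const float dt_v         = dt[i2*d_inner];
        const float dt_soft_plus = dt_v <= 20.0f ? log1pf(expf(dt_v)) : dt_v;
        const float x_dt         = x[i2*d_inner] * dt_soft_plus;

        float partial = 0.0f;
        if (lane < d_state) {
            const float prev  = (i2 == 0) ? s0[lane] : s[lane];
            // state = prev_state * dA + dB * x (one state element per lane)
            const float state = prev*expf(dt_soft_plus*A[lane]) + B[i2*d_state + lane]*x_dt;
            s[lane] = state;
            partial = state * C[i2*d_state + lane];
        }
        // y = dot(state, C): combine the per-lane partial products
#pragma unroll
        for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
            partial += __shfl_xor_sync(0xffffffff, partial, offset);
        }
        if (lane == 0) {
            y[i2*d_inner] = partial;
        }
    }
}
```

Whether this wins in practice depends on d_state and occupancy: with d_state = 16 only half of each warp does useful work, so packing two rows per warp might be worth trying.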
`ssm_scan.cuh` (new file, `@@ -0,0 +1,3 @@`):

```cuda
#include "common.cuh"

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```