
Commit 65180fb

faster ssm_scan
1 parent 9f91251 commit 65180fb

File tree

5 files changed: +431 -1 lines changed


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 11 additions & 1 deletion
@@ -31,6 +31,8 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/ssm_conv.cuh"
+#include "ggml-cuda/ssm_scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"

@@ -2155,6 +2157,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_SSM_CONV:
+            ggml_cuda_op_ssm_conv(ctx, dst);
+            break;
+        case GGML_OP_SSM_SCAN:
+            ggml_cuda_op_ssm_scan(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;

@@ -2989,7 +2997,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
-            return true;
+        case GGML_OP_SSM_SCAN:
+        case GGML_OP_SSM_CONV:
+            return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
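With both ops registered in ggml_cuda_compute_forward and ggml_backend_cuda_device_supports_op, the scheduler can both dispatch and offload them. A minimal sketch of how a host might confirm support, assuming the public ggml-backend API; cuda_backend and ssm_conv_node are illustrative names, not part of this commit:

// Sketch: query support for the newly registered op before offloading.
// cuda_backend and ssm_conv_node are assumed to be built elsewhere.
if (ggml_backend_supports_op(cuda_backend, ssm_conv_node)) {
    // after this change, the CUDA backend reports support for
    // F32 GGML_OP_SSM_CONV and GGML_OP_SSM_SCAN nodes
}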

ggml/src/ggml-cuda/ssm_conv.cu

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+#include "ssm_conv.cuh"
+
+template <int block_size>
+static __global__ void ssm_conv_f32(const float *__restrict__ src0,
+                                    const float *__restrict__ src1,
+                                    const int src0_nb0, const int src0_nb1,
+                                    const int src0_nb2, const int src1_nb1,
+                                    float *__restrict__ dst, const int dst_nb0,
+                                    const int dst_nb1, const int dst_nb2,
+                                    const int nc, const int ncs, const int nr,
+                                    const int n_t, const int n_s) {
+    const int tid = blockIdx.y;
+    const int i3  = blockIdx.x;
+    const int i2  = threadIdx.x;
+
+    const int ith = tid;
+    const int nth = WARP_SIZE;
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = min(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    // {d_conv - 1 + n_t, d_inner, n_seqs}
+    // sliding window
+    const float *s =
+        (const float *)((const char *)src0 + ir0 * src0_nb1 + i2 * src0_nb0 +
+                        i3 * src0_nb2); // {d_conv, d_inner, n_s}
+    const float *c = (const float *)((const char *)src1 +
+                                     ir0 * src1_nb1); // {d_conv, d_inner}
+    float *x = (float *)((char *)dst + ir0 * dst_nb0 + i2 * dst_nb1 +
+                         i3 * dst_nb2); // {d_inner, n_t, n_s}
+
+    // TODO: transpose the output for smaller strides for big batches?
+    // d_inner
+    for (int i1 = 0; i1 < ir; ++i1) {
+        // rowwise dot product
+        // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+        float sumf = 0.0f;
+
+        // d_conv
+#pragma unroll
+        for (int i0 = 0; i0 < nc; ++i0) {
+            sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc];
+        }
+        x[i1] = sumf;
+    }
+}
+
+static void ssm_conv_f32_cuda(const float *src0, const float *src1,
+                              const int src0_nb0, const int src0_nb1,
+                              const int src0_nb2, const int src1_nb1,
+                              float *dst, const int dst_nb0, const int dst_nb1,
+                              const int dst_nb2, const int nc, const int ncs,
+                              const int nr, const int n_t, const int n_s,
+                              cudaStream_t stream) {
+    const dim3 block_dims(n_t, 1, 1);
+    // const int nblocks = n_s; // TODO
+    const dim3 grid_dims(n_s, WARP_SIZE, 1);
+
+    ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
+        src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
+        dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+}
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context &ctx, ggml_tensor *dst) {
+    const struct ggml_tensor *src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor *src1 = dst->src[1]; // conv1d.weight
+
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t = dst->ne[1];  // tokens per sequence
+    const int n_s = dst->ne[2];  // number of sequences in the batch
+
+    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+
+    const float *src0_d = (const float *)src0->data;
+    const float *src1_d = (const float *)src1->data;
+    float *dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2],
+                      src1->nb[1], dst_d, dst->nb[0], dst->nb[1], dst->nb[2],
+                      nc, ncs, nr, n_t, n_s, stream);
+}
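Per output element, the kernel computes a width-nc sliding-window dot product over the padded token axis, with each warp lane handling a slice of the d_inner rows. A contiguous-layout CPU reference of the same math; this sketch is for validation only and is not part of the commit:

// CPU reference (sketch). Shapes follow the kernel comments:
// src0 {d_conv - 1 + n_t, d_inner, n_s}, src1 {d_conv, d_inner},
// dst {d_inner, n_t, n_s}; all buffers assumed contiguous.
static void ssm_conv_f32_ref(const float *src0, const float *src1, float *dst,
                             int d_conv, int d_inner, int n_t, int n_s) {
    const int ncs = d_conv - 1 + n_t; // padded input length per row
    for (int i3 = 0; i3 < n_s; ++i3) {
        for (int i2 = 0; i2 < n_t; ++i2) {
            for (int i1 = 0; i1 < d_inner; ++i1) {
                float sumf = 0.0f;
                for (int i0 = 0; i0 < d_conv; ++i0) {
                    // window starts at token i2; weights are per inner channel
                    sumf += src0[(i3 * d_inner + i1) * ncs + i2 + i0] *
                            src1[i1 * d_conv + i0];
                }
                dst[(i3 * n_t + i2) * d_inner + i1] = sumf;
            }
        }
    }
}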

ggml/src/ggml-cuda/ssm_conv.cuh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context& ctx, ggml_tensor* dst);
