#include "ssm_conv.cuh"

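// For each sequence i3 and token i2, each output row i1 is a sliding-window dot
// product: x[i1] = sum_{i0 < nc} s[i0 + i1 * ncs] * c[i0 + i1 * nc], i.e. a
// depthwise causal 1D convolution of conv_x with the per-row conv1d weights.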
template <int block_size>
static __global__ void ssm_conv_f32(const float *__restrict__ src0,
                                    const float *__restrict__ src1,
                                    const int src0_nb0, const int src0_nb1,
                                    const int src0_nb2, const int src1_nb1,
                                    float *__restrict__ dst, const int dst_nb0,
                                    const int dst_nb1, const int dst_nb2,
                                    const int nc, const int ncs, const int nr,
                                    const int n_t, const int n_s) {
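  // grid mapping: blockIdx.x = sequence, threadIdx.x = token within the
  // sequence, blockIdx.y = one of WARP_SIZE partitions of the d_inner rows
  // (the block_size template parameter is currently unused)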
  const int tid = blockIdx.y;
  const int i3 = blockIdx.x;
  const int i2 = threadIdx.x;

  const int ith = tid;
  const int nth = WARP_SIZE;

  // rows per thread
  const int dr = (nr + nth - 1) / nth;

  // row range for this thread
  const int ir0 = dr * ith;
  const int ir1 = min(ir0 + dr, nr);
  const int ir = ir1 - ir0;

  // {d_conv - 1 + n_t, d_inner, n_seqs}
  // sliding window
  const float *s =
      (const float *)((const char *)src0 + ir0 * src0_nb1 + i2 * src0_nb0 +
                      i3 * src0_nb2); // {d_conv, d_inner, n_s}
  const float *c = (const float *)((const char *)src1 +
                                   ir0 * src1_nb1); // {d_conv, d_inner}
  float *x = (float *)((char *)dst + ir0 * dst_nb0 + i2 * dst_nb1 +
                       i3 * dst_nb2); // {d_inner, n_t, n_s}

  // TODO: transpose the output for smaller strides for big batches?
  // d_inner
  for (int i1 = 0; i1 < ir; ++i1) {
    // rowwise dot product
    // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
    float sumf = 0.0f;

    // d_conv
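    // nc (d_conv) is typically small (e.g. 4 for Mamba), so this loop is short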
#pragma unroll
    for (int i0 = 0; i0 < nc; ++i0) {
      sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc];
    }
    x[i1] = sumf;
  }
}

static void ssm_conv_f32_cuda(const float *src0, const float *src1,
                              const int src0_nb0, const int src0_nb1,
                              const int src0_nb2, const int src1_nb1,
                              float *dst, const int dst_nb0, const int dst_nb1,
                              const int dst_nb2, const int nc, const int ncs,
                              const int nr, const int n_t, const int n_s,
                              cudaStream_t stream) {
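  // launch geometry: gridDim.x = n_s sequences, gridDim.y = WARP_SIZE row
  // partitions; each thread handles one token, so this assumes n_t does not
  // exceed the maximum threads per block (1024)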
  const dim3 block_dims(n_t, 1, 1);
  // const int nblocks = n_s; // TODO
  const dim3 grid_dims(n_s, WARP_SIZE, 1);

  ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
      src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1,
      dst_nb2, nc, ncs, nr, n_t, n_s);
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context &ctx, ggml_tensor *dst) {
  const struct ggml_tensor *src0 = dst->src[0]; // conv_x
  const struct ggml_tensor *src1 = dst->src[1]; // conv1d.weight

  const int nc  = src1->ne[0]; // d_conv
  const int ncs = src0->ne[0]; // d_conv - 1 + n_t
  const int nr  = src0->ne[1]; // d_inner
  const int n_t = dst->ne[1];  // tokens per sequence
  const int n_s = dst->ne[2];  // number of sequences in the batch

  GGML_ASSERT(dst->ne[0] == nr);
  GGML_ASSERT(src0->nb[0] == sizeof(float));
  GGML_ASSERT(src1->nb[0] == sizeof(float));
  GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
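  // the kernel indexes src0 as a dense {ncs, nr} slab per sequence
  // (s[i0 + i1 * ncs]), hence the element-size/contiguity asserts above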

  const float *src0_d = (const float *)src0->data;
  const float *src1_d = (const float *)src1->data;
  float *dst_d = (float *)dst->data;
  cudaStream_t stream = ctx.stream();

  GGML_ASSERT(src0->type == GGML_TYPE_F32);
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
  ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2],
                    src1->nb[1], dst_d, dst->nb[0], dst->nb[1], dst->nb[2], nc,
                    ncs, nr, n_t, n_s, stream);
}