#include "ssm_conv.cuh"
// SSM (state-space model) 1D convolution kernel.
//
// Launch layout (set up by ssm_conv_f32_cuda):
//   gridDim.x  = n_s       (sequences in the batch)
//   gridDim.y  = WARP_SIZE (row-chunks of d_inner, one chunk per y-block)
//   blockDim.x = n_t       (tokens per sequence, one token per thread)
//
// Strides (src0_nb*, dst_nb*) are in BYTES, ggml-style, hence the char* casts.
// NOTE(review): the `block_size` template parameter is currently unused by the
// body; it is kept for interface compatibility with the <WARP_SIZE> launch.
template <int block_size>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
                                    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
    const int tid = blockIdx.y;   // which row-chunk of d_inner this block handles
    const int i3  = blockIdx.x;   // sequence index
    const int i2  = threadIdx.x;  // token index within the sequence

    const int ith = tid;
    const int nth = WARP_SIZE;

    // rows per thread
    const int dr = (nr + nth - 1) / nth;

    // row range for this thread (clamped so the last chunk does not overrun nr)
    const int ir0 = dr * ith;
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    // {d_conv - 1 + n_t, d_inner, n_seqs}
    // sliding window: each token i2 reads a d_conv-wide window starting at
    // offset i2 along the innermost dimension of src0
    const float * s = (const float *) ((const char *) src0 + ir0 * src0_nb1 + i2 * src0_nb0 +
                                       i3 * src0_nb2);                        // {d_conv, d_inner, n_s}
    const float * c = (const float *) ((const char *) src1 + ir0 * src1_nb1); // {d_conv, d_inner}
    float * x = (float *) ((char *) dst + ir0 * dst_nb0 + i2 * dst_nb1 + i3 * dst_nb2); // {d_inner, n_t, n_s}

    // TODO: transpose the output for smaller strides for big batches?
    // d_inner
    for (int i1 = 0; i1 < ir; ++i1) {
        // rowwise dot product
        // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
        float sumf = 0.0f;

        // d_conv
#pragma unroll
        for (int i0 = 0; i0 < nc; ++i0) {
            sumf += s[i0 + i1 * ncs] * c[i0 + i1 * nc];
        }
        x[i1] = sumf;
    }
}
5245
// Host-side launcher for ssm_conv_f32.
//
// Grid: (n_s, WARP_SIZE) blocks; block: n_t threads — one thread per token,
// one y-block per chunk of d_inner rows. All nb* strides are in bytes.
// Asynchronous on `stream`; the caller owns synchronization.
// NOTE(review): assumes n_t <= the device's max threads per block — TODO confirm
// upstream guarantees this for all supported batch shapes.
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0,
                              const int dst_nb1, const int dst_nb2, const int nc, const int ncs, const int nr,
                              const int n_t, const int n_s, cudaStream_t stream) {
    const dim3 block_dims(n_t, 1, 1);
    // const int nblocks = n_s; // TODO
    const dim3 grid_dims(n_s, WARP_SIZE, 1);

    ssm_conv_f32<WARP_SIZE><<<grid_dims, block_dims, 0, stream>>>(
        src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
}
6857
// ggml CUDA backend entry point for GGML_OP_SSM_CONV.
//
// dst->src[0] = conv_x        {d_conv - 1 + n_t, d_inner, n_s}
// dst->src[1] = conv1d.weight {d_conv, d_inner}
// dst         =               {d_inner, n_t, n_s}
//
// Only F32 tensors are supported; src0 must be contiguous in its first two
// dimensions (asserted below). Launches asynchronously on the context stream.
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

    const int nc  = src1->ne[0]; // d_conv
    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
    const int nr  = src0->ne[1]; // d_inner
    const int n_t = dst->ne[1];  // tokens per sequence
    const int n_s = dst->ne[2];  // number of sequences in the batch

    GGML_ASSERT(dst->ne[0] == nr);
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0],
                      dst->nb[1], dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
}