@@ -4,13 +4,14 @@ template <size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
-                                    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
+                                    const int64_t n_t) {
+    GGML_UNUSED(src0_nb0);
     const int tid  = threadIdx.x;
     const int bidx = blockIdx.x;
     const int bidy = blockIdx.y;
 
-    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
-    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
     float *       y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
 
     const int stride_x = src0_nb1 / sizeof(float);
@@ -21,43 +22,42 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
     float w[d_conv] = { 0.0f };
 
 #pragma unroll
-    for (int j = 0; j < d_conv; j++) {
+    for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }
 
-    for (int i = 0; i < n_t; i++) {
+    for (int64_t i = 0; i < n_t; i++) {
         float sumf = 0.0f;
 
         if (i == 0) {
-            for (int j = 0; j < d_conv; j++) {
+            for (size_t j = 0; j < d_conv; j++) {
                 x[j] = x_block[tid * stride_x + j];
             }
         } else {
             x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
         }
 
 #pragma unroll
-        for (int j = 0; j < d_conv; j++) {
+        for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
         y_block[i * stride_y + tid] = sumf;
     }
 }
 
-template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
+template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
-                                               const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
-                                               const int nr, const int n_t, const int n_s) {
+                                               const int dst_nb1, const int dst_nb2, const int64_t n_t) {
     const int tid  = threadIdx.x;
     const int bidx = blockIdx.x;
     const int bidy = blockIdx.y;
     const int bidz = blockIdx.z;
 
-    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
                                              bidz * split_n_t * src0_nb0);
-    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
     float *       y_block =
         (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
 
@@ -69,25 +69,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     float w[d_conv] = { 0.0f };
 
 #pragma unroll
-    for (int j = 0; j < d_conv; j++) {
+    for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }
 
 #pragma unroll
-    for (int i = 0; i < split_n_t; i++) {
+    for (int64_t i = 0; i < split_n_t; i++) {
         if (bidz * split_n_t + i < n_t) {
             float sumf = 0.0f;
 
             if (i == 0) {
-                for (int j = 0; j < d_conv; j++) {
+                for (size_t j = 0; j < d_conv; j++) {
                     x[j] = x_block[tid * stride_x + j];
                 }
             } else {
                 x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
             }
 
 #pragma unroll
-            for (int j = 0; j < d_conv; j++) {
+            for (size_t j = 0; j < d_conv; j++) {
                 sumf += x[(i + j) % d_conv] * w[j];
             }
             y_block[i * stride_y + tid] = sumf;
@@ -97,27 +97,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
 
 static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
-                              const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
-                              const int n_s, cudaStream_t stream) {
+                              const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
+                              const int64_t n_s, cudaStream_t stream) {
     const int threads = 128;
     GGML_ASSERT(nr % threads == 0);
 
     if (n_t <= 32) {
         const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
         if (nc == 4) {
             ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
-                                                                     n_s);
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             GGML_ABORT("Only support kernel size = 4 now.");
         }
     } else {
         if (nc == 4) {
-            const int split_n_t = 32;
-            dim3      blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, 4, split_n_t>
-                <<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
-                                                 dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+            const int64_t split_n_t = 32;
+            dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
+                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             GGML_ABORT("Only support kernel size = 4 right now.");
         }
@@ -128,11 +126,10 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // conv_x
     const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
 
-    const int nc  = src1->ne[0];                    // d_conv
-    const int ncs = src0->ne[0];                    // d_conv - 1 + n_t
-    const int nr  = src0->ne[1];                    // d_inner
-    const int n_t = dst->ne[1];                     // tokens per sequence
-    const int n_s = dst->ne[2];                     // number of sequences in the batch
+    const int64_t nc  = src1->ne[0];                // d_conv
+    const int64_t nr  = src0->ne[1];                // d_inner
+    const int64_t n_t = dst->ne[1];                 // tokens per sequence
+    const int64_t n_s = dst->ne[2];                 // number of sequences in the batch
 
     GGML_ASSERT(dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -147,5 +144,5 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
-                      dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
+                      dst->nb[2], nc, nr, n_t, n_s, stream);
 }
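
For reference, both kernels compute the same depthwise causal convolution over the left-padded conv_x: each output token is the dot product of a d_conv-wide window of a channel's input with that channel's conv1d weight, which is what the rotating x[(i + j) % d_conv] buffer implements without re-reading the window from global memory. A minimal CPU sketch of that computation, assuming contiguous layouts and using illustrative names (ssm_conv_ref is not part of this change or the repository), can be used to cross-check the kernel output:

// CPU reference for the convolution computed above (a sketch, not the repo's test harness).
// Assumed contiguous layouts:
//   x: [n_s][d_inner][d_conv - 1 + n_t]   (conv_x, already left-padded)
//   w: [d_inner][d_conv]                  (conv1d.weight)
//   y: [n_s][n_t][d_inner]                (dst)
#include <cstdint>
#include <vector>

static void ssm_conv_ref(const std::vector<float> & x, const std::vector<float> & w, std::vector<float> & y,
                         int64_t d_conv, int64_t d_inner, int64_t n_t, int64_t n_s) {
    const int64_t ncs = d_conv - 1 + n_t;  // padded input length per channel
    for (int64_t s = 0; s < n_s; s++) {
        for (int64_t t = 0; t < n_t; t++) {
            for (int64_t r = 0; r < d_inner; r++) {
                float sumf = 0.0f;
                for (int64_t j = 0; j < d_conv; j++) {
                    // window [t, t + d_conv) of channel r, weighted by that channel's filter
                    sumf += x[(s * d_inner + r) * ncs + t + j] * w[r * d_conv + j];
                }
                y[(s * n_t + t) * d_inner + r] = sumf;
            }
        }
    }
}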