@@ -4,13 +4,14 @@ template <size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
-                                    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
+                                    const int64_t n_t) {
+    (void) src0_nb0;
     const int tid  = threadIdx.x;
     const int bidx = blockIdx.x;
     const int bidy = blockIdx.y;
 
-    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
-    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
     float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
 
     const int stride_x = src0_nb1 / sizeof(float);
@@ -21,43 +22,42 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
     float w[d_conv] = { 0.0f };
 
 #pragma unroll
-    for (int j = 0; j < d_conv; j++) {
+    for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }
 
-    for (int i = 0; i < n_t; i++) {
+    for (int64_t i = 0; i < n_t; i++) {
         float sumf = 0.0f;
 
         if (i == 0) {
-            for (int j = 0; j < d_conv; j++) {
+            for (size_t j = 0; j < d_conv; j++) {
                 x[j] = x_block[tid * stride_x + j];
             }
         } else {
             x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
         }
 
 #pragma unroll
-        for (int j = 0; j < d_conv; j++) {
+        for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
         y_block[i * stride_y + tid] = sumf;
     }
 }
 
-template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
+template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
-                                               const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
-                                               const int nr, const int n_t, const int n_s) {
+                                               const int dst_nb1, const int dst_nb2, const int64_t n_t) {
     const int tid  = threadIdx.x;
     const int bidx = blockIdx.x;
     const int bidy = blockIdx.y;
     const int bidz = blockIdx.z;
 
-    const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+    const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
                                              bidz * split_n_t * src0_nb0);
-    const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
+    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
     float * y_block =
         (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
 
@@ -69,25 +69,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     float w[d_conv] = { 0.0f };
 
 #pragma unroll
-    for (int j = 0; j < d_conv; j++) {
+    for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }
 
 #pragma unroll
-    for (int i = 0; i < split_n_t; i++) {
+    for (int64_t i = 0; i < split_n_t; i++) {
         if (bidz * split_n_t + i < n_t) {
             float sumf = 0.0f;
 
             if (i == 0) {
-                for (int j = 0; j < d_conv; j++) {
+                for (size_t j = 0; j < d_conv; j++) {
                     x[j] = x_block[tid * stride_x + j];
                 }
             } else {
                 x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
             }
 
 #pragma unroll
-            for (int j = 0; j < d_conv; j++) {
+            for (size_t j = 0; j < d_conv; j++) {
                 sumf += x[(i + j) % d_conv] * w[j];
             }
             y_block[i * stride_y + tid] = sumf;
@@ -97,27 +97,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
 
 static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
-                              const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
-                              const int n_s, cudaStream_t stream) {
+                              const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
+                              const int64_t n_s, cudaStream_t stream) {
     const int threads = 128;
     GGML_ASSERT(nr % threads == 0);
 
     if (n_t <= 32) {
         const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
         if (nc == 4) {
             ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
-                                                                     n_s);
+                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             GGML_ABORT("Only support kernel size = 4 now.");
         }
     } else {
         if (nc == 4) {
-            const int split_n_t = 32;
-            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<threads, 4, split_n_t>
-                <<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
-                                                 dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
+            const int64_t split_n_t = 32;
+            dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
+                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             GGML_ABORT("Only support kernel size = 4 right now.");
         }
@@ -128,11 +126,10 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0]; // conv_x
     const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
 
-    const int nc  = src1->ne[0]; // d_conv
-    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
-    const int nr  = src0->ne[1]; // d_inner
-    const int n_t = dst->ne[1];  // tokens per sequence
-    const int n_s = dst->ne[2];  // number of sequences in the batch
+    const int64_t nc  = src1->ne[0]; // d_conv
+    const int64_t nr  = src0->ne[1]; // d_inner
+    const int64_t n_t = dst->ne[1];  // tokens per sequence
+    const int64_t n_s = dst->ne[2];  // number of sequences in the batch
 
     GGML_ASSERT(dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -147,5 +144,5 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
-                      dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
+                      dst->nb[2], nc, nr, n_t, n_s, stream);
 }
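
For context (not part of the patch): per channel, both kernels compute a causal convolution of width d_conv over the token axis. Below is a minimal plain-C++ sketch of that per-channel loop, assuming d_conv == 4 as in the CUDA path above; the rolling-buffer indexing x[(i + j) % d_conv] in the kernels is equivalent to the straight x[i + j] window here, and ssm_conv_channel_ref is a hypothetical name used only for illustration.

#include <cstddef>

// Hypothetical single-channel reference of the SSM convolution (illustration only).
// x holds n_t + d_conv - 1 input samples, w holds d_conv weights, y receives n_t outputs.
template <size_t d_conv = 4>
void ssm_conv_channel_ref(const float * x, const float * w, float * y, long long n_t) {
    for (long long i = 0; i < n_t; ++i) {
        float sumf = 0.0f;
        for (size_t j = 0; j < d_conv; ++j) {
            sumf += x[i + j] * w[j];  // window of d_conv inputs ending at token i
        }
        y[i] = sumf;
    }
}

The kernels parallelize this loop over channels (threadIdx.x within split_d_inner-sized chunks along blockIdx.y) and sequences (blockIdx.x); the long-token variant additionally splits the token axis across blockIdx.z in chunks of split_n_t = 32, which is why it re-checks bidz * split_n_t + i < n_t before writing.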