@@ -107,17 +107,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
107
107
if (nc == 4 ) {
108
108
ssm_conv_f32<threads, 4 ><<<blocks, threads, 0 , stream>>> (src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
109
109
dst, dst_nb0, dst_nb1, dst_nb2, n_t );
110
+ } else if (nc == 3 ) {
111
+ ssm_conv_f32<threads, 3 ><<<blocks, threads, 0 , stream>>> (src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
112
+ dst, dst_nb0, dst_nb1, dst_nb2, n_t );
110
113
} else {
111
- GGML_ABORT (" Only support kernel size = 4 now." );
114
+ GGML_ABORT (" Only support kernel size = 3 or size = 4 right now." );
112
115
}
113
116
} else {
114
117
if (nc == 4 ) {
115
118
const int64_t split_n_t = 32 ;
116
119
dim3 blocks (n_s, (nr + threads - 1 ) / threads, (n_t + split_n_t - 1 ) / split_n_t );
117
120
ssm_conv_long_token_f32<threads, 4 , split_n_t ><<<blocks, threads, 0 , stream>>> (
118
121
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t );
122
+ } else if (nc == 3 ) {
123
+ const int64_t split_n_t = 32 ;
124
+ dim3 blocks (n_s, (nr + threads - 1 ) / threads, (n_t + split_n_t - 1 ) / split_n_t );
125
+ ssm_conv_long_token_f32<threads, 3 , split_n_t ><<<blocks, threads, 0 , stream>>> (
126
+ src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t );
119
127
} else {
120
- GGML_ABORT (" Only support kernel size = 4 right now." );
128
+ GGML_ABORT (" Only support kernel size = 3 or size = 4 right now." );
121
129
}
122
130
}
123
131
}
0 commit comments