@@ -8918,7 +8918,8 @@ void ggml_compute_forward_flash_attn_back(
 
 // ggml_compute_forward_ssm_conv
 
-static void ggml_compute_forward_ssm_conv_f32(
+template <typename src_t, typename conv_t>
+static void ggml_compute_forward_ssm_conv_impl(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0]; // conv_x
@@ -8934,9 +8935,10 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int n_s = dst->ne[2]; // number of sequences in the batch
 
     GGML_ASSERT( dst->ne[0] == nr);
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(src_t));
+    GGML_ASSERT(src1->nb[0] == sizeof(conv_t));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(src_t));
+    GGML_ASSERT(dst->type == src0->type);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -8950,9 +8952,9 @@ static void ggml_compute_forward_ssm_conv_f32(
         for (int i2 = 0; i2 < n_t; ++i2) {
             // {d_conv - 1 + n_t, d_inner, n_seqs}
             // sliding window
-            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
-            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
-            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+            const src_t * s = (const src_t *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const conv_t * c = (const conv_t *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            src_t * x = (src_t *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
 
             // TODO: transpose the output for smaller strides for big batches?
             // d_inner
@@ -8963,22 +8965,149 @@ static void ggml_compute_forward_ssm_conv_f32(
 
                 // d_conv
                 for (int i0 = 0; i0 < nc; ++i0) {
-                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                    sumf += type_conversion_table<src_t>::to_f32(s[i0 + i1*ncs]) * type_conversion_table<conv_t>::to_f32(c[i0 + i1*nc]);
                 }
-                x[i1] = sumf;
-            }
-        }
-    }
-}
+                x[i1] = type_conversion_table<src_t>::from_f32(sumf);
+            }
+        }
+    }
+}
+
+// static void ggml_compute_forward_ssm_conv_q_f32(
+//         const ggml_compute_params * params,
+//         ggml_tensor * dst) {
+//     const ggml_tensor * src0 = dst->src[0]; // conv_x
+//     const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+
+//     const int ith = params->ith;
+//     const int nth = params->nth;
+
+//     const int nc  = src1->ne[0]; // d_conv
+//     const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+//     const int nr  = src0->ne[1]; // d_inner
+//     const int n_t =  dst->ne[1]; // tokens per sequence
+//     const int n_s =  dst->ne[2]; // number of sequences in the batch
+
+//     const ggml_type type0 = src0->type;
+//     const size_t type0_size = ggml_type_size(type0);
+//     ggml_to_float_t   const dequantize_row0_q = ggml_get_type_traits(type0)->to_float;
+//     ggml_from_float_t const quantize_row0_q   = ggml_get_type_traits_cpu(type0)->from_float;
+
+//     const ggml_type type1 = src1->type;
+//     const size_t type1_size = ggml_type_size(type1);
+//     ggml_to_float_t   const dequantize_row1_q = ggml_get_type_traits(type1)->to_float;
+//     ggml_from_float_t const quantize_row1_q   = ggml_get_type_traits_cpu(type1)->from_float;
+
+//     GGML_ASSERT( dst->ne[0] == nr);
+//     GGML_ASSERT(src0->nb[0] == type0_size);
+//     GGML_ASSERT(src1->nb[0] == type1_size);
+//     GGML_ASSERT(src0->nb[1] == src0->ne[0]*type0_size);
+//     GGML_ASSERT(dst->type == src0->type);
+
+//     // rows per thread
+//     const int dr = (nr + nth - 1)/nth;
+
+//     // row range for this thread
+//     const int ir0 = dr*ith;
+//     const int ir1 = MIN(ir0 + dr, nr);
+//     const int ir  = ir1 - ir0;
+
+//     // temporary storage for dequantized lines
+//     float * wdata = (float *) params->wdata + (src0->ne[0] + CACHE_LINE_SIZE_F32) * ith;
+
+//     for (int i3 = 0; i3 < n_s; ++i3) {
+//         for (int i2 = 0; i2 < n_t; ++i2) {
+//             // {d_conv - 1 + n_t, d_inner, n_seqs}
+//             // sliding window
+//             const void * s = (const void *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+//             const void * c = (const void *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+//             void * x = (void *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+//             // TODO: transpose the output for smaller strides for big batches?
+//             // d_inner
+//             for (int i1 = 0; i1 < ir; ++i1) {
+//                 // rowwise dot product
+//                 // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+//                 float sumf = 0.0f;
+
+//                 // d_conv
+//                 for (int i0 = 0; i0 < nc; ++i0) {
+//                     sumf += type_conversion_table<T>::to_f32(s[i0 + i1*ncs]) * type_conversion_table<T>::to_f32(c[i0 + i1*nc]);
+//                 }
+//                 x[i1] = type_conversion_table<T>::from_f32(sumf);
+//             }
+//         }
+//     }
+// }
 
 void ggml_compute_forward_ssm_conv(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
     switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_conv_f32(params, dst);
+                switch (dst->src[1]->type) {
+                    case GGML_TYPE_F32:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<float, float>(params, dst);
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<float, ggml_fp16_t>(params, dst);
+                        } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<float, ggml_bf16_t>(params, dst);
+                        } break;
+                    default:
+                        {
+                            GGML_ABORT("fatal error");
+                        }
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                switch (dst->src[1]->type) {
+                    case GGML_TYPE_F32:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_fp16_t, float>(params, dst);
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_fp16_t, ggml_fp16_t>(params, dst);
+                        } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_fp16_t, ggml_bf16_t>(params, dst);
+                        } break;
+                    default:
+                        {
+                            GGML_ABORT("fatal error");
+                        }
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                switch (dst->src[1]->type) {
+                    case GGML_TYPE_F32:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_bf16_t, float>(params, dst);
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_bf16_t, ggml_fp16_t>(params, dst);
+                        } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_bf16_t, ggml_bf16_t>(params, dst);
+                        } break;
+                    default:
+                        {
+                            GGML_ABORT("fatal error");
+                        }
+                }
             } break;
+        // TODO: Support quantized types
         default:
             {
                 GGML_ABORT("fatal error");
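
Note: the templated kernel above relies on a type_conversion_table<T> helper that is not part of this diff. For readers following along, here is a minimal sketch of what such a table could look like, assuming it simply wraps ggml's existing scalar conversion macros (GGML_FP16_TO_FP32 / GGML_FP32_TO_FP16 and GGML_BF16_TO_FP32 / GGML_FP32_TO_BF16); the actual definition lives elsewhere in the tree and may differ:

template <typename T> struct type_conversion_table;

// f32 passes through unchanged
template <> struct type_conversion_table<float> {
    static float to_f32  (float x) { return x; }
    static float from_f32(float x) { return x; }
};

// fp16 <-> f32 through the existing ggml macros
template <> struct type_conversion_table<ggml_fp16_t> {
    static float       to_f32  (ggml_fp16_t x) { return GGML_FP16_TO_FP32(x); }
    static ggml_fp16_t from_f32(float x)       { return GGML_FP32_TO_FP16(x); }
};

// bf16 <-> f32 through the existing ggml macros
template <> struct type_conversion_table<ggml_bf16_t> {
    static float       to_f32  (ggml_bf16_t x) { return GGML_BF16_TO_FP32(x); }
    static ggml_bf16_t from_f32(float x)       { return GGML_FP32_TO_BF16(x); }
};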
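
The commented-out ggml_compute_forward_ssm_conv_q_f32 sketches one possible route to the dispatcher's quantization TODO: dequantize each row into per-thread f32 scratch in params->wdata (via the to_float/from_float row helpers returned by ggml_get_type_traits() and ggml_get_type_traits_cpu()), accumulate the dot product in f32, and requantize the output row. As written it does not compile: it indexes through void pointers and still uses the placeholder T from the template version instead of calling the dequantize_row*_q helpers it sets up. It also cannot keep the per-element sliding-window offset i2*(src0->nb[0]) as-is, because the elements of a quantized row are packed into blocks and are not individually addressable by a fixed stride.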
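
For reference, a minimal smoke test for one of the new mixed-type paths, driven through the public ggml_ssm_conv API (shapes per the comments in the kernel: conv state {d_conv - 1 + n_t, d_inner, n_s}, weight {d_conv, d_inner}). This is a sketch under the assumption that the graph-level op accepts an f16 weight after this change; tensor data is left uninitialized since it only checks that the new dispatch path runs:

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(ip);

    const int64_t d_conv = 4, d_inner = 8, n_t = 3, n_s = 1;

    // conv state {d_conv - 1 + n_t, d_inner, n_s} in f32 -> src_t = float
    ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_t, d_inner, n_s);
    // conv weight {d_conv, d_inner} in f16 -> conv_t = ggml_fp16_t
    ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner);

    ggml_tensor * x = ggml_ssm_conv(ctx, sx, c); // {d_inner, n_t, n_s}

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, x);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    ggml_free(ctx);
    return 0;
}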