Commit 82bba1d

feat(ggml-cpu): Add f16 and bf16 support for ssm_conv
Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 8b6f38a commit 82bba1d
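
The templated kernel in this commit reads both inputs through a type_conversion_table<T> helper, so every element is converted to f32 before the dot product and the accumulated sum is converted back to the source type on store. That helper is defined elsewhere in ops.cpp and is not part of this diff; as a rough sketch only, its specializations could look roughly like the following, assuming ggml's usual conversion macros (GGML_FP16_TO_FP32 / GGML_FP32_TO_FP16 and GGML_BF16_TO_FP32 / GGML_FP32_TO_BF16):

// Hypothetical sketch -- the real type_conversion_table lives elsewhere in
// ops.cpp and may differ. It maps each supported element type to f32 and
// back, so the templated kernel can accumulate in f32 regardless of the
// storage type of conv_x and conv1d.weight.
template <typename T> struct type_conversion_table;

template <> struct type_conversion_table<float> {
    static float to_f32(float x)   { return x; } // no-op for f32
    static float from_f32(float x) { return x; }
};

template <> struct type_conversion_table<ggml_fp16_t> {
    static float       to_f32(ggml_fp16_t x) { return GGML_FP16_TO_FP32(x); }
    static ggml_fp16_t from_f32(float x)     { return GGML_FP32_TO_FP16(x); }
};

template <> struct type_conversion_table<ggml_bf16_t> {
    static float       to_f32(ggml_bf16_t x) { return GGML_BF16_TO_FP32(x); }
    static ggml_bf16_t from_f32(float x)     { return GGML_FP32_TO_BF16(x); }
};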

File tree

1 file changed (+143, -14 lines)

ggml/src/ggml-cpu/ops.cpp

Lines changed: 143 additions & 14 deletions
@@ -8918,7 +8918,8 @@ void ggml_compute_forward_flash_attn_back(
 
 // ggml_compute_forward_ssm_conv
 
-static void ggml_compute_forward_ssm_conv_f32(
+template<typename src_t, typename conv_t>
+static void ggml_compute_forward_ssm_conv_impl(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0]; // conv_x
@@ -8934,9 +8935,10 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int n_s = dst->ne[2]; // number of sequences in the batch
 
     GGML_ASSERT( dst->ne[0] == nr);
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(src_t));
+    GGML_ASSERT(src1->nb[0] == sizeof(conv_t));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(src_t));
+    GGML_ASSERT(dst->type == src0->type);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -8950,9 +8952,9 @@ static void ggml_compute_forward_ssm_conv_f32(
         for (int i2 = 0; i2 < n_t; ++i2) {
             // {d_conv - 1 + n_t, d_inner, n_seqs}
             // sliding window
-            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
-            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
-            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+            const src_t  * s = (const src_t  *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const conv_t * c = (const conv_t *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+                  src_t  * x = (      src_t  *) ((      char *)  dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s}
 
             // TODO: transpose the output for smaller strides for big batches?
             // d_inner
@@ -8963,22 +8965,149 @@ static void ggml_compute_forward_ssm_conv_f32(
 
                 // d_conv
                 for (int i0 = 0; i0 < nc; ++i0) {
-                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                    sumf += type_conversion_table<src_t>::to_f32(s[i0 + i1*ncs]) * type_conversion_table<conv_t>::to_f32(c[i0 + i1*nc]);
                 }
-                x[i1] = sumf;
-            }
-        }
-    }
-}
+                x[i1] = type_conversion_table<src_t>::from_f32(sumf);
+            }
+        }
+    }
+}
+
+// static void ggml_compute_forward_ssm_conv_q_f32(
+//         const ggml_compute_params * params,
+//               ggml_tensor * dst) {
+//     const ggml_tensor * src0 = dst->src[0]; // conv_x
+//     const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+
+//     const int ith = params->ith;
+//     const int nth = params->nth;
+
+//     const int nc  = src1->ne[0]; // d_conv
+//     const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+//     const int nr  = src0->ne[1]; // d_inner
+//     const int n_t =  dst->ne[1]; // tokens per sequence
+//     const int n_s =  dst->ne[2]; // number of sequences in the batch
+
+//     const ggml_type type0 = src0->type;
+//     const size_t type0_size = ggml_type_size(type0);
+//     ggml_to_float_t   const dequantize_row0_q = ggml_get_type_traits(type0)->to_float;
+//     ggml_from_float_t const quantize_row0_q   = ggml_get_type_traits_cpu(type0)->from_float;
+
+//     const ggml_type type1 = src1->type;
+//     const size_t type1_size = ggml_type_size(type1);
+//     ggml_to_float_t   const dequantize_row1_q = ggml_get_type_traits(type1)->to_float;
+//     ggml_from_float_t const quantize_row1_q   = ggml_get_type_traits_cpu(type1)->from_float;
+
+//     GGML_ASSERT( dst->ne[0] == nr);
+//     GGML_ASSERT(src0->nb[0] == type0_size);
+//     GGML_ASSERT(src1->nb[0] == type1_size);
+//     GGML_ASSERT(src0->nb[1] == src0->ne[0]*type0_size);
+//     GGML_ASSERT(dst->type == src0->type);
+
+//     // rows per thread
+//     const int dr = (nr + nth - 1)/nth;
+
+//     // row range for this thread
+//     const int ir0 = dr*ith;
+//     const int ir1 = MIN(ir0 + dr, nr);
+//     const int ir  = ir1 - ir0;
+
+//     // temporary storage for dequantized lines
+//     float * wdata = (float *) params->wdata + (src0->ne[0] + CACHE_LINE_SIZE_F32) * ith;
+
+//     for (int i3 = 0; i3 < n_s; ++i3) {
+//         for (int i2 = 0; i2 < n_t; ++i2) {
+//             // {d_conv - 1 + n_t, d_inner, n_seqs}
+//             // sliding window
+//             const void * s = (const void *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+//             const void * c = (const void *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+//             void * x = ( void *) (( char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+//             // TODO: transpose the output for smaller strides for big batches?
+//             // d_inner
+//             for (int i1 = 0; i1 < ir; ++i1) {
+//                 // rowwise dot product
+//                 // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+//                 float sumf = 0.0f;
+
+//                 // d_conv
+//                 for (int i0 = 0; i0 < nc; ++i0) {
+//                     sumf += type_conversion_table<T>::to_f32(s[i0 + i1*ncs]) * type_conversion_table<T>::to_f32(c[i0 + i1*nc]);
+//                 }
+//                 x[i1] = type_conversion_table<T>::from_f32(sumf);
+//             }
+//         }
+//     }
+// }
 
 void ggml_compute_forward_ssm_conv(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
     switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_conv_f32(params, dst);
+                switch (dst->src[1]->type) {
+                    case GGML_TYPE_F32:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<float, float>(params, dst);
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<float, ggml_fp16_t>(params, dst);
+                        } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<float, ggml_bf16_t>(params, dst);
+                        } break;
+                    default:
+                        {
+                            GGML_ABORT("fatal error");
+                        }
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                switch (dst->src[1]->type) {
+                    case GGML_TYPE_F32:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_fp16_t, float>(params, dst);
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_fp16_t, ggml_fp16_t>(params, dst);
+                        } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_fp16_t, ggml_bf16_t>(params, dst);
+                        } break;
+                    default:
+                        {
+                            GGML_ABORT("fatal error");
+                        }
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                switch (dst->src[1]->type) {
+                    case GGML_TYPE_F32:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_bf16_t, float>(params, dst);
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_bf16_t, ggml_fp16_t>(params, dst);
+                        } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            ggml_compute_forward_ssm_conv_impl<ggml_bf16_t, ggml_bf16_t>(params, dst);
+                        } break;
+                    default:
+                        {
+                            GGML_ABORT("fatal error");
+                        }
+                }
             } break;
+        // TODO: Support quantized types
         default:
             {
                 GGML_ABORT("fatal error");
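
For reference, the kernel implements a per-channel (depthwise) sliding-window convolution: for every sequence, token and inner channel it accumulates d_conv products in f32 and writes the sum back in the type of conv_x. A standalone sketch of that loop over plain float buffers, illustrative only and not ggml code (the sizes below are made up for the example):

#include <cstdio>
#include <vector>

// Illustrative re-statement of the ssm_conv inner loop, single sequence.
// Shapes follow the kernel's comments: s is {d_conv - 1 + n_t, d_inner},
// c is {d_conv, d_inner}, x is {d_inner, n_t}.
int main() {
    const int d_conv  = 4;             // nc in the kernel
    const int n_t     = 3;             // tokens per sequence
    const int d_inner = 2;             // nr in the kernel
    const int ncs     = d_conv - 1 + n_t;

    std::vector<float> s(ncs * d_inner, 1.0f);    // conv_x (sliding-window input)
    std::vector<float> c(d_conv * d_inner, 0.5f); // conv1d.weight
    std::vector<float> x(d_inner * n_t, 0.0f);    // output

    for (int i2 = 0; i2 < n_t; ++i2) {            // token == window start
        for (int i1 = 0; i1 < d_inner; ++i1) {    // inner channel
            float sumf = 0.0f;                    // accumulate in f32, as the kernel does
            for (int i0 = 0; i0 < d_conv; ++i0) {
                sumf += s[i2 + i0 + i1*ncs] * c[i0 + i1*d_conv];
            }
            x[i1 + i2*d_inner] = sumf;
        }
    }

    printf("x[0,0] = %f\n", x[0]); // 4 * 1.0 * 0.5 = 2.0
    return 0;
}

With f16 or bf16 inputs, the only difference in the real kernel is that each load goes through type_conversion_table<...>::to_f32 and the final store through type_conversion_table<src_t>::from_f32, as shown in the diff above.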
