
Commit 6256f9a

feat(ggml-cpu): Add partial implementation of scale for f16
This is used to zero-out the state in build_rs, so it's required to support F16 cache states for recurrent models. The bias route does not get hit in that case, but would need to be implemented if used elsewhere.

Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 7ad0f37 commit 6256f9a
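
For context on how this kernel gets exercised: zeroing a recurrent state comes down to a plain ggml_scale with a factor of 0.0f on an F16 tensor, which before this commit fell through to the GGML_ABORT in the default case of ggml_compute_forward_scale. Below is a minimal, hypothetical sketch of that pattern against the public ggml API, assuming the classic CPU path via ggml_graph_compute_with_ctx; the context size, tensor shape, and thread count are made up for illustration, and this is not the actual build_rs code.

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    // small scratch context (size is arbitrary for this sketch)
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // hypothetical F16 recurrent-state tensor (contents left uninitialized for brevity)
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 128, 4);

    // scaling by 0.0f zeroes the state; with this commit the F16 case
    // dispatches to ggml_compute_forward_scale_f16 instead of aborting
    struct ggml_tensor * zeroed = ggml_scale(ctx, state, 0.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, zeroed);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
    return 0;
}

Since no bias is set, only the b == 0.0f branch of the new kernel is hit, which matches the note above that the bias route does not get exercised in this use case.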

1 file changed: +58 -0 lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 58 additions & 0 deletions
@@ -4558,6 +4558,60 @@ static void ggml_compute_forward_scale_f32(
     }
 }
 
+static void ggml_compute_forward_scale_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb1 = dst->nb[1];
+
+    if (b == 0.0f) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            if (dst->data != src0->data) {
+                // src0 is same shape as dst => same indices
+                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
+                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(ggml_fp16_t));
+            }
+            ggml_vec_scale_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*nb1), s);
+        }
+    } else {
+        // TODO: support bias!
+        GGML_ABORT("fatal error");
+        // for (int i1 = ir0; i1 < ir1; i1++) {
+        //     ggml_vec_mad1_f16(nc,
+        //             (ggml_fp16_t *) ((char *) dst->data + i1*nb1),
+        //             (ggml_fp16_t *) ((char *) src0->data + i1*nb1),
+        //             s, b);
+        // }
+    }
+}
+
 void ggml_compute_forward_scale(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -4569,6 +4623,10 @@ void ggml_compute_forward_scale(
             {
                 ggml_compute_forward_scale_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_scale_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
