Skip to content

Commit 50f88fc

Browse files
committed
ggml : add ggml_scale_bias
1 parent f667f1e commit 50f88fc

File tree

6 files changed

+59
-17
lines changed

6 files changed

+59
-17
lines changed

ggml/include/ggml.h

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1185,6 +1185,19 @@ extern "C" {
11851185
struct ggml_tensor * a,
11861186
float s);
11871187

1188+
// x = s * a + b
1189+
GGML_API struct ggml_tensor * ggml_scale_bias(
1190+
struct ggml_context * ctx,
1191+
struct ggml_tensor * a,
1192+
float s,
1193+
float b);
1194+
1195+
GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
1196+
struct ggml_context * ctx,
1197+
struct ggml_tensor * a,
1198+
float s,
1199+
float b);
1200+
11881201
// b -> view(a,offset,nb1,nb2,3), return modified a
11891202
GGML_API struct ggml_tensor * ggml_set(
11901203
struct ggml_context * ctx,

ggml/src/ggml-cpu/ops.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3937,9 +3937,11 @@ static void ggml_compute_forward_scale_f32(
39373937
GGML_ASSERT(ggml_is_contiguous(dst));
39383938
GGML_ASSERT(ggml_are_same_shape(src0, dst));
39393939

3940-
// scale factor
3941-
float v;
3942-
memcpy(&v, dst->op_params, sizeof(float));
3940+
float s; // scale factor
3941+
float b; // bias
3942+
3943+
memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
3944+
memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
39433945

39443946
const int ith = params->ith;
39453947
const int nth = params->nth;
@@ -3963,7 +3965,10 @@ static void ggml_compute_forward_scale_f32(
39633965
// src0 is same shape as dst => same indices
39643966
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
39653967
}
3966-
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
3968+
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
3969+
if (b != 0.0f) {
3970+
ggml_vec_acc1_f32(nc, (float *) ((char *) dst->data + i1*nb1), b);
3971+
}
39673972
}
39683973
}
39693974

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2189,8 +2189,8 @@ static bool ggml_metal_encode_node(
21892189
{
21902190
GGML_ASSERT(ggml_is_contiguous(src0));
21912191

2192-
float scale;
2193-
memcpy(&scale, dst->op_params, sizeof(scale));
2192+
float scale = ((const float *)(dst->op_params))[0];
2193+
float bias = ((const float *)(dst->op_params))[1];
21942194

21952195
int64_t n = ggml_nelements(dst);
21962196

@@ -2207,6 +2207,7 @@ static bool ggml_metal_encode_node(
22072207
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
22082208
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
22092209
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
2210+
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];
22102211

22112212
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
22122213
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -810,16 +810,18 @@ kernel void kernel_scale(
810810
device const float * src0,
811811
device float * dst,
812812
constant float & scale,
813+
constant float & bias,
813814
uint tpig[[thread_position_in_grid]]) {
814-
dst[tpig] = src0[tpig] * scale;
815+
dst[tpig] = src0[tpig] * scale + bias;
815816
}
816817

817818
kernel void kernel_scale_4(
818819
device const float4 * src0,
819820
device float4 * dst,
820821
constant float & scale,
822+
constant float & bias,
821823
uint tpig[[thread_position_in_grid]]) {
822-
dst[tpig] = src0[tpig] * scale;
824+
dst[tpig] = src0[tpig] * scale + bias;
823825
}
824826

825827
kernel void kernel_clamp(

ggml/src/ggml.c

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2858,12 +2858,14 @@ static struct ggml_tensor * ggml_scale_impl(
28582858
struct ggml_context * ctx,
28592859
struct ggml_tensor * a,
28602860
float s,
2861+
float b,
28612862
bool inplace) {
28622863
GGML_ASSERT(ggml_is_padded_1d(a));
28632864

28642865
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
28652866

2866-
ggml_set_op_params(result, &s, sizeof(s));
2867+
float params[2] = { s, b };
2868+
ggml_set_op_params(result, &params, sizeof(params));
28672869

28682870
result->op = GGML_OP_SCALE;
28692871
result->src[0] = a;
@@ -2875,14 +2877,30 @@ struct ggml_tensor * ggml_scale(
28752877
struct ggml_context * ctx,
28762878
struct ggml_tensor * a,
28772879
float s) {
2878-
return ggml_scale_impl(ctx, a, s, false);
2880+
return ggml_scale_impl(ctx, a, s, 0.0, false);
28792881
}
28802882

28812883
struct ggml_tensor * ggml_scale_inplace(
28822884
struct ggml_context * ctx,
28832885
struct ggml_tensor * a,
28842886
float s) {
2885-
return ggml_scale_impl(ctx, a, s, true);
2887+
return ggml_scale_impl(ctx, a, s, 0.0, true);
2888+
}
2889+
2890+
struct ggml_tensor * ggml_scale_bias(
2891+
struct ggml_context * ctx,
2892+
struct ggml_tensor * a,
2893+
float s,
2894+
float b) {
2895+
return ggml_scale_impl(ctx, a, s, b, false);
2896+
}
2897+
2898+
struct ggml_tensor * ggml_scale_bias_inplace(
2899+
struct ggml_context * ctx,
2900+
struct ggml_tensor * a,
2901+
float s,
2902+
float b) {
2903+
return ggml_scale_impl(ctx, a, s, b, true);
28862904
}
28872905

28882906
// ggml_set
@@ -5472,7 +5490,7 @@ static void ggml_compute_backward(
54725490
} break;
54735491
case GGML_OP_MEAN: {
54745492
if (src0_needs_grads) {
5475-
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
5493+
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
54765494
}
54775495
} break;
54785496
case GGML_OP_REPEAT: {
@@ -5549,7 +5567,7 @@ static void ggml_compute_backward(
55495567
if (src0_needs_grads) {
55505568
float s;
55515569
memcpy(&s, tensor->op_params, sizeof(float));
5552-
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
5570+
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
55535571
}
55545572
} break;
55555573
case GGML_OP_SET: {

tests/test-backend-ops.cpp

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1655,22 +1655,24 @@ struct test_scale : public test_case {
16551655
const ggml_type type;
16561656
const std::array<int64_t, 4> ne;
16571657
float scale;
1658+
float bias;
16581659

16591660
std::string vars() override {
1660-
return VARS_TO_STR3(type, ne, scale);
1661+
return VARS_TO_STR4(type, ne, scale, bias);
16611662
}
16621663

16631664
test_scale(ggml_type type = GGML_TYPE_F32,
16641665
std::array<int64_t, 4> ne = {10, 10, 10, 10},
1665-
float scale = 2.0f)
1666-
: type(type), ne(ne), scale(scale) {}
1666+
float scale = 2.0f,
1667+
float bias = 0.0f)
1668+
: type(type), ne(ne), scale(scale), bias(bias) {}
16671669

16681670
ggml_tensor * build_graph(ggml_context * ctx) override {
16691671
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
16701672
ggml_set_param(a);
16711673
ggml_set_name(a, "a");
16721674

1673-
ggml_tensor * out = ggml_scale(ctx, a, scale);
1675+
ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
16741676
ggml_set_name(out, "out");
16751677

16761678
return out;
@@ -4209,6 +4211,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
42094211

42104212
test_cases.emplace_back(new test_add1());
42114213
test_cases.emplace_back(new test_scale());
4214+
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
42124215
test_cases.emplace_back(new test_silu_back());
42134216

42144217
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {

0 commit comments

Comments (0)