
Commit a7c9d33

implement scale op
1 parent 415b2d5 commit a7c9d33

File tree: 3 files changed, +160 −7 lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 57 additions & 3 deletions
@@ -138,6 +138,7 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline rms_norm_pipeline[2];   // inplace
     wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
     wgpu::ComputePipeline glu_pipeline[7][2][2];  // glu-op, type, split
+    wgpu::ComputePipeline scale_pipeline[2];      // inplace

     size_t memset_bytes_per_thread;

@@ -840,9 +841,9 @@ static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
         (uint32_t) dst->ne[0],
         (uint32_t) dst->ne[1],
         (uint32_t) dst->ne[2],
-        (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
-        *(uint32_t *) &dst->op_params[2],           // alpha, for swiglu_oai
-        *(uint32_t *) &dst->op_params[3],           // limit, for swiglu_oai
+        (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
+        *(uint32_t *) &dst->op_params[2],           // alpha, for swiglu_oai
+        *(uint32_t *) &dst->op_params[3],           // limit, for swiglu_oai
     };

     std::vector<wgpu::BindGroupEntry> entries = {
@@ -870,6 +871,45 @@ static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
     ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }

+static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    int inplace = ggml_webgpu_tensor_equal(src, dst);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(dst),
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        *(uint32_t *) dst->op_params,    // scale
+        *(uint32_t *) &dst->op_params[1] // bias
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_webgpu_tensor_buf(src),
+          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
+    };
+    if (!inplace) {
+        entries.push_back({ .binding = 1,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    size_t   max_wg_size = ctx->max_wg_size_x;
+    uint32_t wg_x        = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
+}
+
 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
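A note on the last two entries of `params` above: GGML_OP_SCALE stores scale and bias as two floats in dst->op_params, and the `*(uint32_t *)` casts reinterpret those float bits as u32 so the whole parameter list can travel in a single std::vector<uint32_t>; the WGSL Params struct declares the same slots as f32 and reads the identical bits back. A minimal sketch of an equivalent bit-cast (memcpy-based, which sidesteps strict aliasing; `f32_bits` is a hypothetical helper, not part of this commit):

#include <cstdint>
#include <cstring>

// Reinterpret a float's bit pattern as u32 (no numeric conversion):
// f32_bits(1.0f) == 0x3F800000.
static uint32_t f32_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return u;
}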
@@ -934,6 +974,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
         case GGML_OP_GLU:
             ggml_webgpu_glu(ctx, src0, src1, node);
             break;
+        case GGML_OP_SCALE:
+            ggml_webgpu_scale(ctx, src0, node);
+            break;
         default:
             return false;
     }
@@ -1449,7 +1492,14 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
                                 wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1],
                                 wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
+}

+static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
+                                "scale_f32_inplace", constants);
 }

 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -1628,6 +1678,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                     break;
             }
             break;
+        case GGML_OP_SCALE:
+            supports_op = op->type == GGML_TYPE_F32;
+            break;
         default:
             break;
     }
@@ -1758,6 +1811,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_rms_norm_pipeline(ctx);
     ggml_webgpu_init_rope_pipeline(ctx);
     ggml_webgpu_init_glu_pipeline(ctx);
+    ggml_webgpu_init_scale_pipeline(ctx);

 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
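The dispatch in ggml_webgpu_scale launches one thread per tensor element and rounds the workgroup count up with an integer ceiling division; the shader's `gid.x >= params.ne` guard retires the overshoot. A small sketch of that arithmetic (hypothetical helper; 256 is only an example workgroup size, the real value comes from ctx->max_wg_size_x):

#include <cstddef>
#include <cstdint>

// Ceiling division: smallest workgroup count that covers all elements.
static uint32_t workgroups_for(size_t nelements, size_t max_wg_size) {
    return (uint32_t) ((nelements + max_wg_size - 1) / max_wg_size);
}
// e.g. 10*10*10*10 = 10000 elements at wg_size 256 -> 40 workgroups
// (10240 threads, of which the last 240 return early).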
New file: WGSL shader template for the scale op

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+#define(VARIANTS)
+
+[
+  {
+    "SHADER_NAME": "scale_f32",
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "scale_f32_inplace",
+    "DECLS": ["INPLACE"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    dst[offset] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    src[offset] = val;
+}
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+struct Params {
+    offset_src: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    scale: f32,
+    bias: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    store_scale(src[i_src] * params.scale + params.bias, i_dst);
+}
+#end(SHADER)
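For reference, here is what one invocation of main() computes, transcribed to the CPU: decompose the flat thread index into (i3, i2, i1, i0), resolve the strided source and destination offsets (dimension 0 is contiguous, so its stride is implicitly one element), and apply y = x * scale + bias. This is a sketch mirroring the WGSL above, not code from the commit:

#include <cstdint>

struct ScaleParams {
    uint32_t offset_src, offset_dst;
    uint32_t stride_src1, stride_src2, stride_src3;
    uint32_t stride_dst1, stride_dst2, stride_dst3;
    uint32_t ne, ne0, ne1, ne2;
    float    scale, bias;
};

// CPU transcription of the shader body for a single thread index `tid`.
static void scale_one(const ScaleParams & p, const float * src, float * dst, uint32_t tid) {
    if (tid >= p.ne) {
        return; // same early-out as the shader
    }
    uint32_t i  = tid;
    uint32_t i3 = i / (p.ne2 * p.ne1 * p.ne0);
    i %= p.ne2 * p.ne1 * p.ne0;
    uint32_t i2 = i / (p.ne1 * p.ne0);
    i %= p.ne1 * p.ne0;
    uint32_t i1 = i / p.ne0;
    uint32_t i0 = i % p.ne0;
    uint32_t i_src = p.offset_src + i3 * p.stride_src3 + i2 * p.stride_src2 + i1 * p.stride_src1 + i0;
    uint32_t i_dst = p.offset_dst + i3 * p.stride_dst3 + i2 * p.stride_dst2 + i1 * p.stride_dst1 + i0;
    dst[i_dst] = src[i_src] * p.scale + p.bias;
}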

tests/test-backend-ops.cpp

Lines changed: 13 additions & 4 deletions
@@ -2677,23 +2677,30 @@ struct test_scale : public test_case {
     const std::array<int64_t, 4> ne;
     float scale;
     float bias;
+    bool inplace;

     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale, bias);
+        return VARS_TO_STR5(type, ne, scale, bias, inplace);
     }

     test_scale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
             float scale = 2.0f,
-            float bias = 0.0f)
-        : type(type), ne(ne), scale(scale), bias(bias) {}
+            float bias = 0.0f,
+            bool inplace = false)
+        : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");

-        ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_scale_bias_inplace(ctx, a, scale, bias);
+        } else {
+            out = ggml_scale_bias(ctx, a, scale, bias);
+        }
         ggml_set_name(out, "out");

         return out;
@@ -6110,6 +6117,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
+    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
+
     test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());

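For reference, the semantics the new cases exercise: ggml_scale_bias and ggml_scale_bias_inplace compute the same y = x * scale + bias, with the inplace variant writing back into the source tensor's buffer (which is why the backend binds only one buffer on that path). A minimal scalar sketch of the expected result, not test-backend-ops internals:

#include <cstddef>

static void ref_scale_bias(const float * x, float * y, size_t n, float scale, float bias) {
    for (size_t i = 0; i < n; ++i) {
        y[i] = x[i] * scale + bias;
    }
}
// Inplace is the aliased call: ref_scale_bias(buf, buf, n, 2.0f, 1.0f)
// matches the {2.0f, 1.0f, true} case added above.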