Skip to content

Commit 0dd415f

Browse files
committed
Simplify inplace operation generation and combine mul/add generation
1 parent f0fc822 commit 0dd415f

File tree

10 files changed

+214
-256
lines changed

10 files changed

+214
-256
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 22 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -696,23 +696,19 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
696696

697697
std::vector<uint32_t> params = {
698698
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
699+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
700+
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
701+
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
702+
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
703+
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
704+
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
705+
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
706+
(uint32_t) src->ne[0],
707+
(uint32_t) src->ne[1],
708+
(uint32_t) src->ne[2],
709+
(uint32_t) src->ne[3],
710+
*(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
699711
};
700-
if (!in_place) {
701-
params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
702-
}
703-
params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
704-
params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
705-
params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
706-
if (!in_place) {
707-
params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
708-
params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
709-
params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
710-
}
711-
params.push_back((uint32_t) src->ne[0]);
712-
params.push_back((uint32_t) src->ne[1]);
713-
params.push_back((uint32_t) src->ne[2]);
714-
params.push_back((uint32_t) src->ne[3]);
715-
params.push_back(*(uint32_t *) dst->op_params); // epsilon, treated as f32 in the shader
716712

717713
std::vector<wgpu::BindGroupEntry> entries = {
718714
{ .binding = 0,
@@ -1266,10 +1262,10 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
12661262
constants);
12671263
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
12681264
constants);
1269-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
1270-
"add_in_place_f32", constants);
1271-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
1272-
"add_in_place_f16", constants);
1265+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_f32_inplace,
1266+
"add_f32_inplace", constants);
1267+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_f16_inplace,
1268+
"add_f16_inplace", constants);
12731269
}
12741270

12751271
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
@@ -1278,18 +1274,18 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
12781274
constants);
12791275
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
12801276
constants);
1281-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
1282-
"mul_in_place_f32", constants);
1283-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
1284-
"mul_in_place_f16", constants);
1277+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_f32_inplace,
1278+
"mul_f32_inplace", constants);
1279+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_f16_inplace,
1280+
"mul_f16_inplace", constants);
12851281
}
12861282

12871283
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
12881284
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
12891285
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
12901286
constants);
1291-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
1292-
"rms_norm_in_place", constants);
1287+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_inplace,
1288+
"rms_norm_inplace", constants);
12931289
}
12941290

12951291
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {

ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl

Lines changed: 0 additions & 44 deletions
This file was deleted.

ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl

Lines changed: 0 additions & 41 deletions
This file was deleted.
Lines changed: 124 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,124 @@
1+
#define(VARIANTS)
2+
3+
[
4+
{
5+
"SHADER_NAME": "add_f32",
6+
"REPLS": {
7+
"TYPE" : "f32",
8+
"OP": "+"
9+
},
10+
"DECLS": ["NOT_INPLACE"]
11+
},
12+
{
13+
"SHADER_NAME": "add_f16",
14+
"REPLS": {
15+
"TYPE" : "f16",
16+
"OP": "+"
17+
},
18+
"DECLS": ["NOT_INPLACE"]
19+
},
20+
{
21+
"SHADER_NAME": "add_f32_inplace",
22+
"REPLS": {
23+
"TYPE" : "f32",
24+
"OP": "+"
25+
},
26+
"DECLS": ["INPLACE"]
27+
},
28+
{
29+
"SHADER_NAME": "add_f16_inplace",
30+
"REPLS": {
31+
"TYPE" : "f16",
32+
"OP": "+"
33+
},
34+
"DECLS": ["INPLACE"]
35+
},
36+
{
37+
"SHADER_NAME": "mul_f32",
38+
"REPLS": {
39+
"TYPE" : "f32",
40+
"OP": "*"
41+
},
42+
"DECLS": ["NOT_INPLACE"]
43+
},
44+
{
45+
"SHADER_NAME": "mul_f16",
46+
"REPLS": {
47+
"TYPE" : "f16",
48+
"OP": "*"
49+
},
50+
"DECLS": ["NOT_INPLACE"]
51+
},
52+
{
53+
"SHADER_NAME": "mul_f32_inplace",
54+
"REPLS": {
55+
"TYPE" : "f32",
56+
"OP": "*"
57+
},
58+
"DECLS": ["INPLACE"]
59+
},
60+
{
61+
"SHADER_NAME": "mul_f16_inplace",
62+
"REPLS": {
63+
"TYPE" : "f16",
64+
"OP": "*"
65+
},
66+
"DECLS": ["INPLACE"]
67+
}
68+
]
69+
70+
#end(VARIANTS)
71+
72+
#define(DECLS)
73+
74+
#decl(NOT_INPLACE)
75+
76+
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
77+
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
78+
}
79+
80+
@group(0) @binding(2)
81+
var<storage, read_write> dst: array<{{TYPE}}>;
82+
83+
@group(0) @binding(3)
84+
var<uniform> params: Params;
85+
86+
#enddecl(NOT_INPLACE)
87+
88+
#decl(INPLACE)
89+
90+
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
91+
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
92+
}
93+
94+
@group(0) @binding(2)
95+
var<uniform> params: Params;
96+
97+
#enddecl(INPLACE)
98+
99+
#end(DECLS)
100+
101+
102+
#define(SHADER)
103+
104+
enable f16;
105+
106+
#include "binary_head.tmpl"
107+
108+
@group(0) @binding(0)
109+
var<storage, read_write> src0: array<{{TYPE}}>;
110+
111+
@group(0) @binding(1)
112+
var<storage, read_write> src1: array<{{TYPE}}>;
113+
114+
DECLS
115+
116+
override wg_size: u32;
117+
@compute @workgroup_size(wg_size)
118+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
119+
if (gid.x < params.ne) {
120+
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
121+
}
122+
}
123+
124+
#end(SHADER)

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -89,14 +89,17 @@ def generate_variants(fname, input_dir, output_dir, outfile):
8989
decls_code += decls_map[key] + "\n\n"
9090

9191
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
92-
final_shader = replace_placeholders(final_shader, variant["REPLS"])
92+
if "REPLS" in variant:
93+
final_shader = replace_placeholders(final_shader, variant["REPLS"])
9394
final_shader = expand_includes(final_shader, input_dir)
9495

95-
if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
96-
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
96+
if "SHADER_NAME" in variant:
97+
output_name = variant["SHADER_NAME"]
9798
elif "SHADER_SUFFIX" in variant:
9899
output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
99-
elif "TYPE" in variant["REPLS"]:
100+
elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
101+
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
102+
elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
100103
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
101104
else:
102105
output_name = shader_base_name

ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl

Lines changed: 0 additions & 44 deletions
This file was deleted.

ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)