Skip to content

Commit 8d78cd2

Browse files
ggml webgpu: support for rope,div,sub,glu,scale,cont operators (ggml-org#16187)
* Work on rope * Simplify inplace operation generation and combine mul/add generation * Work on rope variants * implement neox rope * rope complete * Add sub,div,glu operators * implement scale op * Update cpy shader to handle cont/more types * formatting * Update test vars printing for rope,rms_norm * Avoid ROPE hardcoded constants * Add TODO to change ROPE constants to enum Co-authored-by: Georgi Gerganov <[email protected]> * fix TODO comment --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent d1c84a6 commit 8d78cd2

17 files changed

+1534
-397
lines changed

ggml/include/ggml.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@
237237
#define GGML_EXIT_SUCCESS 0
238238
#define GGML_EXIT_ABORTED 1
239239

240+
// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
241+
#define GGML_ROPE_TYPE_NORMAL 0
240242
#define GGML_ROPE_TYPE_NEOX 2
241243
#define GGML_ROPE_TYPE_MROPE 8
242244
#define GGML_ROPE_TYPE_VISION 24

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 414 additions & 74 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl

Lines changed: 0 additions & 44 deletions
This file was deleted.

ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl

Lines changed: 0 additions & 41 deletions
This file was deleted.
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
#define(VARIANTS)
2+
3+
[
4+
{
5+
"SHADER_NAME": "add_f32",
6+
"REPLS": {
7+
"TYPE" : "f32",
8+
"OP": "+"
9+
},
10+
"DECLS": ["NOT_INPLACE"]
11+
},
12+
{
13+
"SHADER_NAME": "add_f16",
14+
"REPLS": {
15+
"TYPE" : "f16",
16+
"OP": "+"
17+
},
18+
"DECLS": ["NOT_INPLACE"]
19+
},
20+
{
21+
"SHADER_NAME": "add_f32_inplace",
22+
"REPLS": {
23+
"TYPE" : "f32",
24+
"OP": "+"
25+
},
26+
"DECLS": ["INPLACE"]
27+
},
28+
{
29+
"SHADER_NAME": "add_f16_inplace",
30+
"REPLS": {
31+
"TYPE" : "f16",
32+
"OP": "+"
33+
},
34+
"DECLS": ["INPLACE"]
35+
},
36+
{
37+
"SHADER_NAME": "mul_f32",
38+
"REPLS": {
39+
"TYPE" : "f32",
40+
"OP": "*"
41+
},
42+
"DECLS": ["NOT_INPLACE"]
43+
},
44+
{
45+
"SHADER_NAME": "mul_f16",
46+
"REPLS": {
47+
"TYPE" : "f16",
48+
"OP": "*"
49+
},
50+
"DECLS": ["NOT_INPLACE"]
51+
},
52+
{
53+
"SHADER_NAME": "mul_f32_inplace",
54+
"REPLS": {
55+
"TYPE" : "f32",
56+
"OP": "*"
57+
},
58+
"DECLS": ["INPLACE"]
59+
},
60+
{
61+
"SHADER_NAME": "mul_f16_inplace",
62+
"REPLS": {
63+
"TYPE" : "f16",
64+
"OP": "*"
65+
},
66+
"DECLS": ["INPLACE"]
67+
},
68+
{
69+
"SHADER_NAME": "sub_f32",
70+
"REPLS": {
71+
"TYPE" : "f32",
72+
"OP": "-"
73+
},
74+
"DECLS": ["NOT_INPLACE"]
75+
},
76+
{
77+
"SHADER_NAME": "sub_f16",
78+
"REPLS": {
79+
"TYPE" : "f16",
80+
"OP": "-"
81+
},
82+
"DECLS": ["NOT_INPLACE"]
83+
},
84+
{
85+
"SHADER_NAME": "sub_f32_inplace",
86+
"REPLS": {
87+
"TYPE" : "f32",
88+
"OP": "-"
89+
},
90+
"DECLS": ["INPLACE"]
91+
},
92+
{
93+
"SHADER_NAME": "sub_f16_inplace",
94+
"REPLS": {
95+
"TYPE" : "f16",
96+
"OP": "-"
97+
},
98+
"DECLS": ["INPLACE"]
99+
},
100+
{
101+
"SHADER_NAME": "div_f32",
102+
"REPLS": {
103+
"TYPE" : "f32",
104+
"OP": "/"
105+
},
106+
"DECLS": ["NOT_INPLACE"]
107+
},
108+
{
109+
"SHADER_NAME": "div_f16",
110+
"REPLS": {
111+
"TYPE" : "f16",
112+
"OP": "/"
113+
},
114+
"DECLS": ["NOT_INPLACE"]
115+
},
116+
{
117+
"SHADER_NAME": "div_f32_inplace",
118+
"REPLS": {
119+
"TYPE" : "f32",
120+
"OP": "/"
121+
},
122+
"DECLS": ["INPLACE"]
123+
},
124+
{
125+
"SHADER_NAME": "div_f16_inplace",
126+
"REPLS": {
127+
"TYPE" : "f16",
128+
"OP": "/"
129+
},
130+
"DECLS": ["INPLACE"]
131+
}
132+
]
133+
134+
#end(VARIANTS)
135+
136+
#define(DECLS)
137+
138+
#decl(NOT_INPLACE)
139+
140+
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
141+
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
142+
}
143+
144+
@group(0) @binding(2)
145+
var<storage, read_write> dst: array<{{TYPE}}>;
146+
147+
@group(0) @binding(3)
148+
var<uniform> params: Params;
149+
150+
#enddecl(NOT_INPLACE)
151+
152+
#decl(INPLACE)
153+
154+
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
155+
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
156+
}
157+
158+
@group(0) @binding(2)
159+
var<uniform> params: Params;
160+
161+
#enddecl(INPLACE)
162+
163+
#end(DECLS)
164+
165+
166+
#define(SHADER)
167+
168+
enable f16;
169+
170+
#include "binary_head.tmpl"
171+
172+
@group(0) @binding(0)
173+
var<storage, read_write> src0: array<{{TYPE}}>;
174+
175+
@group(0) @binding(1)
176+
var<storage, read_write> src1: array<{{TYPE}}>;
177+
178+
DECLS
179+
180+
override wg_size: u32;
181+
@compute @workgroup_size(wg_size)
182+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
183+
if (gid.x < params.ne) {
184+
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
185+
}
186+
}
187+
188+
#end(SHADER)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#define(VARIANTS)
2+
3+
[
4+
{
5+
"REPLS": {
6+
"SRC_TYPE": "f32",
7+
"DST_TYPE": "f32"
8+
}
9+
},
10+
{
11+
"REPLS": {
12+
"SRC_TYPE": "f32",
13+
"DST_TYPE": "f16"
14+
}
15+
},
16+
{
17+
"REPLS": {
18+
"SRC_TYPE": "f16",
19+
"DST_TYPE": "f16"
20+
}
21+
},
22+
{
23+
"REPLS": {
24+
"SRC_TYPE": "f16",
25+
"DST_TYPE": "f32"
26+
}
27+
}
28+
]
29+
30+
#end(VARIANTS)
31+
32+
#define(SHADER)
33+
enable f16;
34+
35+
@group(0) @binding(0)
36+
var<storage, read_write> src: array<{{SRC_TYPE}}>;
37+
38+
@group(0) @binding(1)
39+
var<storage, read_write> dst: array<{{DST_TYPE}}>;
40+
41+
struct Params {
42+
ne: u32, // total number of elements
43+
offset_src: u32, // in elements
44+
offset_dst: u32, // in elements
45+
46+
// Strides (in elements) — may be permuted
47+
stride_src0: u32,
48+
stride_src1: u32,
49+
stride_src2: u32,
50+
stride_src3: u32,
51+
52+
stride_dst0: u32,
53+
stride_dst1: u32,
54+
stride_dst2: u32,
55+
stride_dst3: u32,
56+
57+
// Logical shapes
58+
src_ne0: u32,
59+
src_ne1: u32,
60+
src_ne2: u32,
61+
62+
dst_ne0: u32,
63+
dst_ne1: u32,
64+
dst_ne2: u32
65+
};
66+
67+
@group(0) @binding(2)
68+
var<uniform> params: Params;
69+
70+
override wg_size: u32;
71+
@compute @workgroup_size(wg_size)
72+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
73+
if (gid.x >= params.ne) {
74+
return;
75+
}
76+
77+
var i = gid.x;
78+
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
79+
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
80+
let i2 = i / (params.src_ne1 * params.src_ne0);
81+
i = i % (params.src_ne1 * params.src_ne0);
82+
let i1 = i / params.src_ne0;
83+
let i0 = i % params.src_ne0;
84+
85+
var j = gid.x;
86+
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
87+
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
88+
let j2 = j / (params.dst_ne1 * params.dst_ne0);
89+
j = j % (params.dst_ne1 * params.dst_ne0);
90+
let j1 = j / params.dst_ne0;
91+
let j0 = j % params.dst_ne0;
92+
93+
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
94+
i2 * params.stride_src2 + i3 * params.stride_src3;
95+
96+
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
97+
j2 * params.stride_dst2 + j3 * params.stride_dst3;
98+
99+
dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
100+
}
101+
#end(SHADER)

0 commit comments

Comments
 (0)