
Commit d304f45

reeselevine and Neha Abbas authored
GGML WebGPU: Support for ADD, MUL, RMS_NORM, GET_ROWS operators (ggml-org#16018)
* Add parameter buffer pool, batching of submissions, refactor command building/submission (a rough sketch of this pattern follows the ggml-webgpu.cpp entry below)
* Add header for linux builds
* Free staged parameter buffers at once
* Format with clang-format
* Fix thread-safe implementation
* Use device implicit synchronization
* Update workflow to use custom release
* Remove testing branch workflow
* some f32 tests passing
* Disable set_rows until it's implemented
* f32 add all tests passing
* Begin work on set_rows
* Work on set rows
* Add error buffers for reporting unsupported SET_ROWS indices
* Remove extra comments
* Add templated addition, clean up code
* Get addition and multiplication working
* Implement rms_norm
* Add get_rows implementation
* Add new get_rows files
* Refactor use of wg size entry
* Fix compilation
* Try manually unrolled q4_0 quant
* Revert "Try manually unrolled q4_0 quant"

  This reverts commit 77f8b96.

* Move to constant max wg size
* Check for tensor size in supports_op
* Vectorize f32 and change default workgroup size
* Move f32 get_rows from < 4 to % 4 != 0
* fix linter errors
* Add in-place tests

---------

Co-authored-by: Neha Abbas <[email protected]>
1 parent 0320ac5 commit d304f45

14 files changed (+2673, -1141 lines)

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 479 additions & 226 deletions
Large diffs are not rendered by default.
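
Most of this commit's host-side logic lives in that un-rendered C++ diff: the parameter buffer pool, batched command submission, and the refactored command building. The following is a rough sketch of the pool-plus-batching pattern only, not the backend's actual code; the names, buffer size, and structure are hypothetical, and Dawn's C++ WebGPU bindings are assumed.

#include <webgpu/webgpu_cpp.h>

#include <vector>

// Hypothetical sketch of a parameter-buffer pool with batched submission.
// NOT the actual ggml-webgpu.cpp implementation; illustrative only.
struct webgpu_param_buf_pool {
    wgpu::Device              device;
    std::vector<wgpu::Buffer> free_bufs;   // recycled uniform buffers
    std::vector<wgpu::Buffer> staged_bufs; // held by the in-flight batch

    static constexpr uint64_t BUF_SIZE = 256; // assumed: fits any Params struct

    // Reuse a pooled buffer if possible, otherwise create one, then upload
    // this dispatch's parameters into it.
    wgpu::Buffer alloc(const void * params, size_t size) {
        wgpu::Buffer buf;
        if (!free_bufs.empty()) {
            buf = free_bufs.back();
            free_bufs.pop_back();
        } else {
            wgpu::BufferDescriptor desc;
            desc.size  = BUF_SIZE;
            desc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
            buf = device.CreateBuffer(&desc);
        }
        device.GetQueue().WriteBuffer(buf, 0, params, size);
        staged_bufs.push_back(buf);
        return buf;
    }

    // "Free staged parameter buffers at once": after a batch is submitted,
    // return every staged buffer to the pool in one go.
    void free_staged() {
        free_bufs.insert(free_bufs.end(), staged_bufs.begin(), staged_bufs.end());
        staged_bufs.clear();
    }
};

// Batching: rather than one Submit per operator, command buffers are
// collected and submitted together.
static void submit_batch(webgpu_param_buf_pool & pool,
                         std::vector<wgpu::CommandBuffer> & batch) {
    pool.device.GetQueue().Submit(batch.size(), batch.data());
    batch.clear();
    pool.free_staged();
}

Recycling staged buffers right after Submit is safe in this sketch because WriteBuffer and Submit execute in queue order, which lines up with the commit's "Use device implicit synchronization" item.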
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
#define(VARIANTS)

[
  {
    "REPLS": {
      "TYPE" : "f32",
    }
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    }
  }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(3)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
    }
}

#end(SHADER)
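
In the add shader above, `wg_size` is a pipeline-overridable constant, so the workgroup size is chosen by the host when the compute pipeline is created (the commit later moves to a constant max workgroup size) rather than being hard-coded in WGSL. A minimal sketch of how the host can specialize it, assuming Dawn's C++ bindings; this is illustrative, not the backend's actual pipeline-creation code.

#include <webgpu/webgpu_cpp.h>

// Hypothetical sketch: specialize the WGSL `override wg_size` constant at
// pipeline-creation time. `device` and `shader_module` are assumed to come
// from the backend's existing setup.
wgpu::ComputePipeline make_add_pipeline(wgpu::Device device,
                                        wgpu::ShaderModule shader_module,
                                        uint32_t wg_size) {
    wgpu::ConstantEntry constants[1];
    constants[0].key   = "wg_size";                    // matches the override's name
    constants[0].value = static_cast<double>(wg_size); // override values are doubles

    wgpu::ComputePipelineDescriptor desc;
    desc.compute.module        = shader_module;
    desc.compute.entryPoint    = "main";
    desc.compute.constants     = constants;
    desc.compute.constantCount = 1;
    return device.CreateComputePipeline(&desc);
}

Each dispatch then launches ceil(params.ne / wg_size) workgroups, and the `gid.x < params.ne` guard in `main` masks off the tail invocations of the last workgroup.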
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
#define(VARIANTS)

[
  {
    "REPLS": {
      "TYPE" : "f32",
    }
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    }
  }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        // in-place variant: the result is written back into the src0 buffer
        // at the destination offset, so no separate dst binding is needed
        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
    }
}

#end(SHADER)
binary_head.tmpl

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
struct Params {
    ne: u32,

    // offsets in elements
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    stride_src1_0: u32,
    stride_src1_1: u32,
    stride_src1_2: u32,
    stride_src1_3: u32,

    a_ne0: u32,
    a_ne1: u32,
    a_ne2: u32,

    b_ne0: u32,
    b_ne1: u32,
    b_ne2: u32,
    b_ne3: u32,
};

fn src1_index(_i: u32) -> u32 {
    var i = _i;
    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
    let a_i2 = i / (params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne1 * params.a_ne0);
    let a_i1 = i / params.a_ne0;
    let a_i0 = i % params.a_ne0;

    // handle repetition of b
    // index loops back to the beginning and repeats after elements are exhausted = modulo
    let b_i0 = a_i0 % params.b_ne0;
    let b_i1 = a_i1 % params.b_ne1;
    let b_i2 = a_i2 % params.b_ne2;
    let b_i3 = a_i3 % params.b_ne3;

    // compute index for position in b's flat array
    return b_i0 * params.stride_src1_0 +
           b_i1 * params.stride_src1_1 +
           b_i2 * params.stride_src1_2 +
           b_i3 * params.stride_src1_3;
}
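
src1_index implements ggml-style broadcasting: the flat destination index is decomposed into 4-D coordinates over src0's shape (a_ne*), each coordinate is wrapped modulo src1's shape (b_ne*), and the wrapped coordinates are re-flattened with src1's strides. The same arithmetic, rendered as host-side C++ with hypothetical shapes purely for illustration:

#include <cstdint>
#include <cstdio>

// Host-side rendering of src1_index for illustration. The shapes and
// strides in main() are hypothetical, not taken from the commit.
struct params_t {
    uint32_t a_ne0, a_ne1, a_ne2;
    uint32_t b_ne0, b_ne1, b_ne2, b_ne3;
    uint32_t stride_src1_0, stride_src1_1, stride_src1_2, stride_src1_3;
};

uint32_t src1_index(uint32_t i, const params_t & p) {
    // decompose the flat index into coordinates over src0's shape
    const uint32_t a_i3 = i / (p.a_ne2 * p.a_ne1 * p.a_ne0);
    i %= p.a_ne2 * p.a_ne1 * p.a_ne0;
    const uint32_t a_i2 = i / (p.a_ne1 * p.a_ne0);
    i %= p.a_ne1 * p.a_ne0;
    const uint32_t a_i1 = i / p.a_ne0;
    const uint32_t a_i0 = i % p.a_ne0;

    // wrap each coordinate modulo src1's shape (broadcasting), then
    // re-flatten with src1's strides
    return (a_i0 % p.b_ne0) * p.stride_src1_0 +
           (a_i1 % p.b_ne1) * p.stride_src1_1 +
           (a_i2 % p.b_ne2) * p.stride_src1_2 +
           (a_i3 % p.b_ne3) * p.stride_src1_3;
}

int main() {
    // src0: 2 rows x 3 cols; src1: a single 3-element row to be broadcast
    const params_t p = { 3, 2, 1,   3, 1, 1, 1,   1, 3, 3, 3 };
    for (uint32_t i = 0; i < 6; i++) {
        printf("dst element %u reads src1 element %u\n", i, src1_index(i, p));
    }
    // prints 0 1 2 0 1 2: the 3-element src1 row repeats across both rows
    return 0;
}

Because dimensions of size 1 in src1 always wrap to coordinate 0, a single row (or scalar) in src1 is reused for every row of src0, which matches the repetition behavior the comments in the WGSL above describe.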
