Skip to content

Commit 4561784

Browse files
committed
Vectorize f32 and change default workgroup size
1 parent fc91520 commit 4561784

File tree

3 files changed

+68
-11
lines changed

3 files changed

+68
-11
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ struct webgpu_context_struct {
129129
wgpu::ComputePipeline mul_mat_pipeline[30][2];
130130
wgpu::ComputePipeline set_rows_pipeline;
131131
wgpu::ComputePipeline get_rows_pipeline[30];
132+
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
132133
wgpu::ComputePipeline cpy_pipeline;
133134
wgpu::ComputePipeline add_pipeline[2];
134135
wgpu::ComputePipeline add_ip_pipeline[2];
@@ -595,8 +596,11 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
595596
size_t max_wg_size = ctx->max_wg_size_x;
596597
uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size;
597598

598-
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->get_rows_pipeline[src->type], params, entries, wg_x,
599-
ggml_op_name(dst->op));
599+
wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type];
600+
if (src->type == GGML_TYPE_F32 && dst->ne[0] < 4) {
601+
pipeline = ctx->get_rows_f32_no_vec_pipeline;
602+
}
603+
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
600604
}
601605

602606
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -1117,7 +1121,9 @@ static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
11171121

11181122
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
11191123
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1120-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32,
1124+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
1125+
"get_rows_f32_vec", constants);
1126+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
11211127
"get_rows_f32", constants);
11221128
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
11231129
"get_rows_f16", constants);
@@ -1423,7 +1429,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
14231429
GGML_ASSERT(ctx->adapter != nullptr);
14241430

14251431
ctx->adapter.GetLimits(&ctx->limits);
1426-
ctx->max_wg_size_x = 256; // default value
1432+
ctx->max_wg_size_x = 288; // default value
14271433

14281434
wgpu::AdapterInfo info{};
14291435
ctx->adapter.GetInfo(&info);

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def generate_variants(fname, input_dir, output_dir, outfile):
9292

9393
if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
9494
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
95+
elif "TYPE_SUFFIX" in variant["REPLS"]:
96+
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"]
9597
elif "TYPE" in variant["REPLS"]:
9698
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
9799
else:

ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,156 +1,187 @@
11
#define(VARIANTS)
22

33
[
4+
{
5+
"REPLS": {
6+
"TYPE" : "vec4<f32>",
7+
"TYPE_SUFFIX": "f32_vec",
8+
"DST_TYPE": "vec4<f32>",
9+
"BLOCK_SIZE": 4
10+
},
11+
"DECLS": ["F32_VEC"]
12+
},
413
{
514
"REPLS": {
615
"TYPE" : "f32",
16+
"DST_TYPE": "f32",
717
"BLOCK_SIZE": 1
818
},
9-
"DECLS": ["FLOAT"]
19+
"DECLS": ["F32"]
1020
},
1121
{
1222
"REPLS": {
1323
"TYPE" : "f16",
24+
"DST_TYPE": "f32",
1425
"BLOCK_SIZE": 1
1526
},
16-
"DECLS": ["FLOAT"]
27+
"DECLS": ["F16"]
1728
},
1829
{
1930
"REPLS": {
2031
"TYPE" : "i32",
32+
"DST_TYPE": "i32",
2133
"BLOCK_SIZE": 1
2234
},
23-
"DECLS": ["FLOAT"]
35+
"DECLS": ["I32"]
2436
},
2537
{
2638
"REPLS": {
2739
"TYPE" : "q4_0",
40+
"DST_TYPE": "f32",
2841
"BLOCK_SIZE": 32
2942
},
3043
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
3144
},
3245
{
3346
"REPLS": {
3447
"TYPE" : "q4_1",
48+
"DST_TYPE": "f32",
3549
"BLOCK_SIZE": 32
3650
},
3751
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
3852
},
3953
{
4054
"REPLS": {
4155
"TYPE" : "q5_0",
56+
"DST_TYPE": "f32",
4257
"BLOCK_SIZE": 32
4358
},
4459
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
4560
},
4661
{
4762
"REPLS": {
4863
"TYPE" : "q5_1",
64+
"DST_TYPE": "f32",
4965
"BLOCK_SIZE": 32
5066
},
5167
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
5268
},
5369
{
5470
"REPLS": {
5571
"TYPE" : "q8_0",
72+
"DST_TYPE": "f32",
5673
"BLOCK_SIZE": 32
5774
},
5875
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
5976
},
6077
{
6178
"REPLS": {
6279
"TYPE" : "q2_k",
80+
"DST_TYPE": "f32",
6381
"BLOCK_SIZE": 256
6482
},
6583
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
6684
},
6785
{
6886
"REPLS": {
6987
"TYPE" : "q3_k",
88+
"DST_TYPE": "f32",
7089
"BLOCK_SIZE": 256
7190
},
7291
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
7392
},
7493
{
7594
"REPLS": {
7695
"TYPE" : "q4_k",
96+
"DST_TYPE": "f32",
7797
"BLOCK_SIZE": 256
7898
},
7999
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
80100
},
81101
{
82102
"REPLS": {
83103
"TYPE" : "q5_k",
104+
"DST_TYPE": "f32",
84105
"BLOCK_SIZE": 256
85106
},
86107
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
87108
},
88109
{
89110
"REPLS": {
90111
"TYPE" : "q6_k",
112+
"DST_TYPE": "f32",
91113
"BLOCK_SIZE": 256
92114
},
93115
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
94116
},
95117
{
96118
"REPLS": {
97119
"TYPE" : "iq2_xxs",
120+
"DST_TYPE": "f32",
98121
"BLOCK_SIZE": 256
99122
},
100123
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
101124
},
102125
{
103126
"REPLS": {
104127
"TYPE" : "iq2_xs",
128+
"DST_TYPE": "f32",
105129
"BLOCK_SIZE": 256
106130
},
107131
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
108132
},
109133
{
110134
"REPLS": {
111135
"TYPE": "iq2_s",
136+
"DST_TYPE": "f32",
112137
"BLOCK_SIZE": 256
113138
},
114139
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
115140
},
116141
{
117142
"REPLS": {
118143
"TYPE": "iq3_xxs",
144+
"DST_TYPE": "f32",
119145
"BLOCK_SIZE": 256
120146
},
121147
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
122148
},
123149
{
124150
"REPLS": {
125151
"TYPE": "iq3_s",
152+
"DST_TYPE": "f32",
126153
"BLOCK_SIZE": 256
127154
},
128155
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
129156
},
130157
{
131158
"REPLS": {
132159
"TYPE": "iq1_s",
160+
"DST_TYPE": "f32",
133161
"BLOCK_SIZE": 256
134162
},
135163
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
136164
},
137165
{
138166
"REPLS": {
139167
"TYPE": "iq1_m",
168+
"DST_TYPE": "f32",
140169
"BLOCK_SIZE": 256
141170
},
142171
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
143172
},
144173
{
145174
"REPLS": {
146175
"TYPE": "iq4_nl",
176+
"DST_TYPE": "f32",
147177
"BLOCK_SIZE": 32,
148178
},
149179
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
150180
},
151181
{
152182
"REPLS": {
153183
"TYPE": "iq4_xs",
184+
"DST_TYPE": "f32",
154185
"BLOCK_SIZE": 256,
155186
},
156187
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
@@ -161,11 +192,29 @@
161192

162193
#define(DECLS)
163194

164-
#decl(FLOAT)
195+
#decl(F32_VEC)
196+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
197+
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
198+
}
199+
#enddecl(F32_VEC)
200+
201+
#decl(F32)
202+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
203+
dst[dst_base + offset] = src[src_base + offset];
204+
}
205+
#enddecl(F32)
206+
207+
#decl(F16)
165208
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
166209
dst[dst_base + offset] = f32(src[src_base + offset]);
167210
}
168-
#enddecl(FLOAT)
211+
#enddecl(F16)
212+
213+
#decl(I32)
214+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
215+
dst[dst_base + offset] = src[src_base + offset];
216+
}
217+
#enddecl(I32)
169218

170219
#decl(Q4_0)
171220
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
@@ -759,7 +808,7 @@ var<storage, read_write> src: array<{{TYPE}}>;
759808
var<storage, read_write> idx: array<i32>;
760809

761810
@group(0) @binding(2)
762-
var<storage, read_write> dst: array<f32>;
811+
var<storage, read_write> dst: array<{{DST_TYPE}}>;
763812

764813
struct Params {
765814
offset_src: u32, // in elements
@@ -822,4 +871,4 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
822871
}
823872
}
824873

825-
#end(SHADER)
874+
#end(SHADER)

0 commit comments

Comments
 (0)