Skip to content

Commit f368820

Browse files
authored
Merge branch 'ggml-org:master' into mradermacher
2 parents 7ffce14 + e15cd06 commit f368820

File tree

8 files changed

+20043
-745
lines changed

8 files changed

+20043
-745
lines changed

docs/ops.md

Lines changed: 108 additions & 108 deletions
Large diffs are not rendered by default.

docs/ops/WebGPU.csv

Lines changed: 18741 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 100 additions & 259 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,12 @@ layout(push_constant) uniform parameter {
3232
uint32_t Cin;
3333
uint32_t N;
3434

35-
// Tensor spatial sizes: kernel, input, output
36-
uint32_t KW;
37-
uint32_t KH;
35+
// Tensor spatial sizes: input, output
3836
uint32_t W;
3937
uint32_t H;
4038
uint32_t OW;
4139
uint32_t OH;
4240

43-
// Parameters: stride, padding, dilation - 0=y, 1=x
44-
uint32_t s0;
45-
uint32_t s1;
46-
uint32_t p0;
47-
uint32_t p1;
48-
uint32_t d0;
49-
uint32_t d1;
50-
5141
// Strides in elements
5242
uint32_t nb01;
5343
uint32_t nb02;
@@ -77,13 +67,14 @@ layout(constant_id = 3) const uint BS_NPQ = 128;
7767
layout(constant_id = 4) const uint TS_K = 8;
7868
layout(constant_id = 5) const uint use_collectives = 1;
7969
layout(constant_id = 6) const uint SHMEM_PAD = 4;
80-
70+
// Stride, padding, dilation
8171
layout(constant_id = 7) const uint s0 = 1;
8272
layout(constant_id = 8) const uint s1 = 1;
8373
layout(constant_id = 9) const uint p0 = 0;
8474
layout(constant_id = 10) const uint p1 = 0;
8575
layout(constant_id = 11) const uint d0 = 1;
8676
layout(constant_id = 12) const uint d1 = 1;
77+
// Kernel spatial sizes
8778
layout(constant_id = 13) const uint KW = 1;
8879
layout(constant_id = 14) const uint KH = 1;
8980

@@ -138,7 +129,7 @@ P,Q=OH,OW
138129
*/
139130

140131
uint32_t B_idx_K = gl_WorkGroupID.x;
141-
uint32_t B_idx_NPQ = gl_WorkGroupID.y;
132+
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
142133

143134
uint32_t T_y = tid / NT_NPQ;
144135
uint32_t T_x = tid % NT_NPQ;
@@ -178,6 +169,10 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T
178169
#endif
179170

180171
void main() {
172+
if (B_idx_NPQ * BS_NPQ >= NPQ) {
173+
return;
174+
}
175+
181176
#ifdef COOPMAT2
182177
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
183178
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 606 additions & 364 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ def parse_decls(decls_text):
1919
return decls
2020

2121

22+
def replace_repl_placeholders(variant, template_map):
23+
for repl, code in variant["REPLS"].items():
24+
for key, val in template_map.items():
25+
# Match "key" and avoid matching subsequences using by using \b
26+
code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
27+
variant["REPLS"][repl] = code
28+
return variant
29+
30+
2231
def replace_placeholders(shader_text, replacements):
2332
for key, val in replacements.items():
2433
# Match {{KEY}} literally, where KEY is escaped
@@ -71,6 +80,10 @@ def generate_variants(fname, input_dir, output_dir, outfile):
7180
decls_map = parse_decls(extract_block(text, "DECLS"))
7281
except ValueError:
7382
decls_map = {}
83+
try:
84+
templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
85+
except ValueError:
86+
templates_map = {}
7487

7588
for fname in sorted(os.listdir(input_dir)):
7689
if fname.endswith(".tmpl"):
@@ -90,9 +103,11 @@ def generate_variants(fname, input_dir, output_dir, outfile):
90103
if key not in decls_map:
91104
raise ValueError(f"DECLS key '{key}' not found.")
92105
decls_code += decls_map[key] + "\n\n"
93-
94106
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
95107
if "REPLS" in variant:
108+
variant = replace_repl_placeholders(variant, templates_map)
109+
final_shader = replace_placeholders(final_shader, variant["REPLS"])
110+
# second run to expand placeholders in repl_template
96111
final_shader = replace_placeholders(final_shader, variant["REPLS"])
97112
final_shader = expand_includes(final_shader, input_dir)
98113

0 commit comments

Comments
 (0)