Skip to content

Commit 4536363

Browse files
authored
ggml WebGPU: add support for quantization types (ggml-org#15440)
* Begin work on set_rows * Work on set rows * Add error buffers for reporting unsupported SET_ROWS indices * Remove extra comments * Work on templating for different types in shaders * Work on shader type generation * Working q4_0 mul_mat and some templating for different types * Add q4_0_f16 matmul and fix device init * Add matmul support for basic quantization types * Add q2_k and q3_k quantization * Add rest of k-quants * Get firt i-quant working * Closer to supporting all i-quants * Support rest of i-quants * Cleanup code * Fix python formatting * debug * Bugfix for memset * Add padding to end of buffers on creation * Simplify bit-shifting * Update usage of StringView
1 parent 32732f2 commit 4536363

File tree

6 files changed

+2145
-245
lines changed

6 files changed

+2145
-245
lines changed

ggml/src/ggml-webgpu/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ add_custom_command(
2020
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
2121
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
2222
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
23-
--input "${SHADER_DIR}"
24-
--output "${SHADER_HEADER}"
23+
--input_dir "${SHADER_DIR}"
24+
--output_file "${SHADER_HEADER}"
2525
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
2626
VERBATIM
2727
)

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 269 additions & 157 deletions
Large diffs are not rendered by default.
Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,85 @@
11
import os
2+
import re
3+
import ast
24
import argparse
35

46

5-
def escape_triple_quotes(wgsl):
6-
# Simple defense in case of embedded """
7-
return wgsl.replace('"""', '\\"""')
7+
def extract_block(text, name):
8+
pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
9+
match = re.search(pattern, text, re.DOTALL)
10+
if not match:
11+
raise ValueError(f"Missing block: {name}")
12+
return match.group(1).strip()
813

914

10-
def to_cpp_string_literal(varname, content):
11-
return f'const char* wgsl_{varname} = R"({content})";\n'
15+
def parse_decls(decls_text):
16+
decls = {}
17+
for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
18+
decls[name.strip()] = code.strip()
19+
return decls
20+
21+
22+
def replace_placeholders(shader_text, replacements):
23+
for key, val in replacements.items():
24+
# Match {{KEY}} literally, where KEY is escaped
25+
pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
26+
shader_text = re.sub(pattern, str(val), shader_text)
27+
return shader_text
28+
29+
30+
def write_shader(shader_name, shader_code, output_dir, outfile):
31+
if output_dir:
32+
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
33+
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
34+
f_out.write(shader_code)
35+
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
36+
37+
38+
def generate_variants(shader_path, output_dir, outfile):
39+
shader_base_name = shader_path.split("/")[-1].split(".")[0]
40+
41+
with open(shader_path, "r", encoding="utf-8") as f:
42+
text = f.read()
43+
44+
try:
45+
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
46+
except ValueError:
47+
write_shader(shader_base_name, text, output_dir, outfile)
48+
else:
49+
decls_map = parse_decls(extract_block(text, "DECLS"))
50+
shader_template = extract_block(text, "SHADER")
51+
52+
for variant in variants:
53+
decls = variant["DECLS"]
54+
decls_code = ""
55+
for key in decls:
56+
if key not in decls_map:
57+
raise ValueError(f"DECLS key '{key}' not found.")
58+
decls_code += decls_map[key] + "\n\n"
59+
60+
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
61+
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
62+
63+
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
64+
write_shader(output_name, final_shader, output_dir, outfile)
1265

1366

1467
def main():
1568
parser = argparse.ArgumentParser()
16-
parser.add_argument('--input', required=True)
17-
parser.add_argument('--output', required=True)
69+
parser.add_argument("--input_dir", required=True)
70+
parser.add_argument("--output_file", required=True)
71+
parser.add_argument("--output_dir")
1872
args = parser.parse_args()
1973

20-
with open(args.output, 'w', encoding='utf-8') as out:
21-
out.write("// Auto-generated shader embedding \n\n")
22-
for fname in sorted(os.listdir(args.input)):
23-
if not fname.endswith('.wgsl'):
24-
continue
25-
shader_path = os.path.join(args.input, fname)
26-
varname = os.path.splitext(fname)[0]
27-
with open(shader_path, 'r', encoding='utf-8') as f:
28-
content = f.read()
29-
content = escape_triple_quotes(content)
30-
out.write(to_cpp_string_literal(varname, content))
31-
out.write('\n')
32-
33-
34-
if __name__ == '__main__':
74+
if args.output_dir:
75+
os.makedirs(args.output_dir, exist_ok=True)
76+
77+
with open(args.output_file, "w", encoding="utf-8") as out:
78+
out.write("// Auto-generated shader embedding\n\n")
79+
for fname in sorted(os.listdir(args.input_dir)):
80+
if fname.endswith(".wgsl"):
81+
generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
82+
83+
84+
if __name__ == "__main__":
3585
main()

ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
1919
let start = params.offset;
2020
let end = params.offset + params.size;
2121

22-
for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
22+
for (var j: u32 = 0u; j < bytes_per_thread; j += 4) {
2323
let byte_index = start + i + j;
24-
if (byte_index + 4u <= end) {
25-
output_buffer[(byte_index >> 2u)] = params.value;
24+
if (byte_index + 4 <= end) {
25+
output_buffer[byte_index >> 2] = params.value;
2626
} else {
2727
// Handle tail (unaligned)
28-
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
28+
for (var k: u32 = 0; k < 4; k++) {
2929
let idx = byte_index + k;
3030
if (idx < end) {
31-
let word_idx = idx >> 2u;
32-
let byte_offset = (idx & 3u) * 8u;
33-
let mask = ~(0xffu << byte_offset);
31+
let word_idx = idx >> 2;
32+
let bit_offset = (idx & 3) * 8u;
33+
let mask = ~(0xffu << bit_offset);
3434
let existing = output_buffer[word_idx];
35-
output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
35+
output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset));
3636
}
3737
}
3838
}

0 commit comments

Comments
 (0)