Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6a6135c
Begin work on set_rows
reeselevine Aug 5, 2025
b2dbfcd
Work on set rows
reeselevine Aug 5, 2025
248f7a5
Add error buffers for reporting unsupported SET_ROWS indices
reeselevine Aug 6, 2025
4ad0986
Remove extra comments
reeselevine Aug 6, 2025
6355137
Work on templating for different types in shaders
reeselevine Aug 7, 2025
831ea3c
Work on shader type generation
reeselevine Aug 7, 2025
688b51d
Working q4_0 mul_mat and some templating for different types
reeselevine Aug 11, 2025
1aa40f1
Add q4_0_f16 matmul and fix device init
reeselevine Aug 12, 2025
c3611f9
Add matmul support for basic quantization types
reeselevine Aug 13, 2025
de4da87
Add q2_k and q3_k quantization
reeselevine Aug 14, 2025
d76e562
Add rest of k-quants
reeselevine Aug 14, 2025
e2380e2
Get firt i-quant working
reeselevine Aug 15, 2025
2a3b9ee
Closer to supporting all i-quants
reeselevine Aug 17, 2025
57c26b1
Support rest of i-quants
reeselevine Aug 19, 2025
7a2ae48
Merge remote-tracking branch 'origin/master' into types
reeselevine Aug 19, 2025
51252f0
Cleanup code
reeselevine Aug 19, 2025
985508e
Merge pull request #2 from reeselevine/types
reeselevine Aug 20, 2025
6552e2e
Fix python formatting
reeselevine Aug 20, 2025
65bebd3
debug
reeselevine Aug 20, 2025
16df269
Bugfix for memset
reeselevine Aug 20, 2025
10babfd
Add padding to end of buffers on creation
reeselevine Aug 21, 2025
7a323b0
Simplify bit-shifting
reeselevine Aug 21, 2025
d1b0ffe
Merge pull request #3 from reeselevine/fixes
reeselevine Aug 21, 2025
d690303
Update usage of StringView
reeselevine Aug 21, 2025
1fcc404
Merge remote-tracking branch 'upstream/master'
reeselevine Aug 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ggml/src/ggml-webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
--input "${SHADER_DIR}"
--output "${SHADER_HEADER}"
--input_dir "${SHADER_DIR}"
--output_file "${SHADER_HEADER}"
DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
VERBATIM
)
Expand Down
426 changes: 269 additions & 157 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp

Large diffs are not rendered by default.

94 changes: 72 additions & 22 deletions ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,85 @@
import os
import re
import ast
import argparse


def escape_triple_quotes(wgsl):
# Simple defense in case of embedded """
return wgsl.replace('"""', '\\"""')
def extract_block(text, name):
pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
match = re.search(pattern, text, re.DOTALL)
if not match:
raise ValueError(f"Missing block: {name}")
return match.group(1).strip()


def to_cpp_string_literal(varname, content):
return f'const char* wgsl_{varname} = R"({content})";\n'
def parse_decls(decls_text):
decls = {}
for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
decls[name.strip()] = code.strip()
return decls


def replace_placeholders(shader_text, replacements):
for key, val in replacements.items():
# Match {{KEY}} literally, where KEY is escaped
pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
shader_text = re.sub(pattern, str(val), shader_text)
return shader_text


def write_shader(shader_name, shader_code, output_dir, outfile):
if output_dir:
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
f_out.write(shader_code)
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')


def generate_variants(shader_path, output_dir, outfile):
shader_base_name = shader_path.split("/")[-1].split(".")[0]

with open(shader_path, "r", encoding="utf-8") as f:
text = f.read()

try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
else:
decls_map = parse_decls(extract_block(text, "DECLS"))
shader_template = extract_block(text, "SHADER")

for variant in variants:
decls = variant["DECLS"]
decls_code = ""
for key in decls:
if key not in decls_map:
raise ValueError(f"DECLS key '{key}' not found.")
decls_code += decls_map[key] + "\n\n"

shader_variant = replace_placeholders(shader_template, variant["REPLS"])
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)

output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
write_shader(output_name, final_shader, output_dir, outfile)


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True)
parser.add_argument('--output', required=True)
parser.add_argument("--input_dir", required=True)
parser.add_argument("--output_file", required=True)
parser.add_argument("--output_dir")
args = parser.parse_args()

with open(args.output, 'w', encoding='utf-8') as out:
out.write("// Auto-generated shader embedding \n\n")
for fname in sorted(os.listdir(args.input)):
if not fname.endswith('.wgsl'):
continue
shader_path = os.path.join(args.input, fname)
varname = os.path.splitext(fname)[0]
with open(shader_path, 'r', encoding='utf-8') as f:
content = f.read()
content = escape_triple_quotes(content)
out.write(to_cpp_string_literal(varname, content))
out.write('\n')


if __name__ == '__main__':
if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)

with open(args.output_file, "w", encoding="utf-8") as out:
out.write("// Auto-generated shader embedding\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)


if __name__ == "__main__":
main()
16 changes: 8 additions & 8 deletions ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let start = params.offset;
let end = params.offset + params.size;

for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) {
for (var j: u32 = 0u; j < bytes_per_thread; j += 4) {
let byte_index = start + i + j;
if (byte_index + 4u <= end) {
output_buffer[(byte_index >> 2u)] = params.value;
if (byte_index + 4 <= end) {
output_buffer[byte_index >> 2] = params.value;
} else {
// Handle tail (unaligned)
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
for (var k: u32 = 0; k < 4; k++) {
let idx = byte_index + k;
if (idx < end) {
let word_idx = idx >> 2u;
let byte_offset = (idx & 3u) * 8u;
let mask = ~(0xffu << byte_offset);
let word_idx = idx >> 2;
let bit_offset = (idx & 3) * 8u;
let mask = ~(0xffu << bit_offset);
let existing = output_buffer[word_idx];
output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset);
output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset));
}
}
}
Expand Down
Loading
Loading