|
1 | 1 | import os
|
| 2 | +import re |
| 3 | +import ast |
2 | 4 | import argparse
|
3 | 5 |
|
4 | 6 |
|
5 |
| -def escape_triple_quotes(wgsl): |
6 |
| - # Simple defense in case of embedded """ |
7 |
| - return wgsl.replace('"""', '\\"""') |
| 7 | +def extract_block(text, name): |
| 8 | + pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' |
| 9 | + match = re.search(pattern, text, re.DOTALL) |
| 10 | + if not match: |
| 11 | + raise ValueError(f"Missing block: {name}") |
| 12 | + return match.group(1).strip() |
8 | 13 |
|
9 | 14 |
|
10 |
| -def to_cpp_string_literal(varname, content): |
11 |
| - return f'const char* wgsl_{varname} = R"({content})";\n' |
| 15 | +def parse_decls(decls_text): |
| 16 | + decls = {} |
| 17 | + for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): |
| 18 | + decls[name.strip()] = code.strip() |
| 19 | + return decls |
| 20 | + |
| 21 | + |
| 22 | +def replace_placeholders(shader_text, replacements): |
| 23 | + for key, val in replacements.items(): |
| 24 | + # Match {{KEY}} literally, where KEY is escaped |
| 25 | + pattern = r'{{\s*' + re.escape(key) + r'\s*}}' |
| 26 | + shader_text = re.sub(pattern, str(val), shader_text) |
| 27 | + return shader_text |
| 28 | + |
| 29 | + |
| 30 | +def write_shader(shader_name, shader_code, output_dir, outfile): |
| 31 | + if output_dir: |
| 32 | + wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") |
| 33 | + with open(wgsl_filename, "w", encoding="utf-8") as f_out: |
| 34 | + f_out.write(shader_code) |
| 35 | + outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') |
| 36 | + |
| 37 | + |
| 38 | +def generate_variants(shader_path, output_dir, outfile): |
| 39 | + shader_base_name = shader_path.split("/")[-1].split(".")[0] |
| 40 | + |
| 41 | + with open(shader_path, "r", encoding="utf-8") as f: |
| 42 | + text = f.read() |
| 43 | + |
| 44 | + try: |
| 45 | + variants = ast.literal_eval(extract_block(text, "VARIANTS")) |
| 46 | + except ValueError: |
| 47 | + write_shader(shader_base_name, text, output_dir, outfile) |
| 48 | + else: |
| 49 | + decls_map = parse_decls(extract_block(text, "DECLS")) |
| 50 | + shader_template = extract_block(text, "SHADER") |
| 51 | + |
| 52 | + for variant in variants: |
| 53 | + decls = variant["DECLS"] |
| 54 | + decls_code = "" |
| 55 | + for key in decls: |
| 56 | + if key not in decls_map: |
| 57 | + raise ValueError(f"DECLS key '{key}' not found.") |
| 58 | + decls_code += decls_map[key] + "\n\n" |
| 59 | + |
| 60 | + shader_variant = replace_placeholders(shader_template, variant["REPLS"]) |
| 61 | + final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant) |
| 62 | + |
| 63 | + output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) |
| 64 | + write_shader(output_name, final_shader, output_dir, outfile) |
12 | 65 |
|
13 | 66 |
|
14 | 67 | def main():
|
15 | 68 | parser = argparse.ArgumentParser()
|
16 |
| - parser.add_argument('--input', required=True) |
17 |
| - parser.add_argument('--output', required=True) |
| 69 | + parser.add_argument("--input_dir", required=True) |
| 70 | + parser.add_argument("--output_file", required=True) |
| 71 | + parser.add_argument("--output_dir") |
18 | 72 | args = parser.parse_args()
|
19 | 73 |
|
20 |
| - with open(args.output, 'w', encoding='utf-8') as out: |
21 |
| - out.write("// Auto-generated shader embedding \n\n") |
22 |
| - for fname in sorted(os.listdir(args.input)): |
23 |
| - if not fname.endswith('.wgsl'): |
24 |
| - continue |
25 |
| - shader_path = os.path.join(args.input, fname) |
26 |
| - varname = os.path.splitext(fname)[0] |
27 |
| - with open(shader_path, 'r', encoding='utf-8') as f: |
28 |
| - content = f.read() |
29 |
| - content = escape_triple_quotes(content) |
30 |
| - out.write(to_cpp_string_literal(varname, content)) |
31 |
| - out.write('\n') |
32 |
| - |
33 |
| - |
34 |
| -if __name__ == '__main__': |
| 74 | + if args.output_dir: |
| 75 | + os.makedirs(args.output_dir, exist_ok=True) |
| 76 | + |
| 77 | + with open(args.output_file, "w", encoding="utf-8") as out: |
| 78 | + out.write("// Auto-generated shader embedding\n\n") |
| 79 | + for fname in sorted(os.listdir(args.input_dir)): |
| 80 | + if fname.endswith(".wgsl"): |
| 81 | + generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out) |
| 82 | + |
| 83 | + |
| 84 | +if __name__ == "__main__": |
35 | 85 | main()
|
0 commit comments