Skip to content

Commit 2a00311

Browse files
authored
Make sure that check_struct is called where needed (#161)
* Make sure that check_struct is called where needed * consistency * black is the new black * Check for errors in code gen report * implement check_struct * nicer name, without typo * Test that check_struct actually does something * Implement check_struct and apply a few fixes
1 parent 61c1334 commit 2a00311

File tree

16 files changed

+170
-66
lines changed

16 files changed

+170
-66
lines changed

codegen/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def prepare():
2121

2222

2323
def update_api():
24-
""" Update the public API and patch the public-facing API of the backends. """
24+
"""Update the public API and patch the public-facing API of the backends."""
2525

2626
print("## Updating API")
2727

@@ -50,7 +50,7 @@ def update_api():
5050

5151

5252
def update_rs():
53-
""" Update and check the rs backend. """
53+
"""Update and check the rs backend."""
5454

5555
print("## Validating rs.py")
5656

@@ -68,7 +68,7 @@ def update_rs():
6868

6969

7070
def main():
71-
""" Codegen entry point. """
71+
"""Codegen entry point."""
7272

7373
with PrintToFile(os.path.join(lib_dir, "resources", "codegen_report.md")):
7474
print("# Code generatation report")

codegen/apipatcher.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import os
77

8-
from .utils import print, lib_dir, blacken, to_snake_case, to_camel_case, Patcher
9-
from .idlparser import get_idl_parser
8+
from codegen.utils import print, lib_dir, blacken, to_snake_case, to_camel_case, Patcher
9+
from codegen.idlparser import get_idl_parser
1010

1111

1212
def patch_base_api(code):
@@ -42,7 +42,11 @@ def patch_backend_api(code):
4242
base_api_code = f.read().decode()
4343

4444
# Patch!
45-
for patcher in [CommentRemover(), BackendApiPatcher(base_api_code)]:
45+
for patcher in [
46+
CommentRemover(),
47+
BackendApiPatcher(base_api_code),
48+
StructValidationChecker(),
49+
]:
4650
patcher.apply(code)
4751
code = patcher.dumps()
4852
return code
@@ -53,7 +57,7 @@ class CommentRemover(Patcher):
5357
to prevent accumulating comments.
5458
"""
5559

56-
triggers = "# IDL:", "# FIXME: unknown api"
60+
triggers = "# IDL:", "# FIXME: unknown api", "# FIXME: missing check_struct"
5761

5862
def apply(self, code):
5963
self._init(code)
@@ -174,7 +178,7 @@ def patch_properties(self, classname, i1, i2):
174178
self._apidiffs_from_lines(pre_lines, propname)
175179
if self.prop_is_known(classname, propname):
176180
if "@apidiff.add" in pre_lines:
177-
print(f"Error: apidiff.add for known {classname}.{propname}")
181+
print(f"ERROR: apidiff.add for known {classname}.{propname}")
178182
elif "@apidiff.hide" in pre_lines:
179183
pass # continue as normal
180184
old_line = self.lines[j1]
@@ -207,7 +211,7 @@ def patch_methods(self, classname, i1, i2):
207211
self._apidiffs_from_lines(pre_lines, methodname)
208212
if self.method_is_known(classname, methodname):
209213
if "@apidiff.add" in pre_lines:
210-
print(f"Error: apidiff.add for known {classname}.{methodname}")
214+
print(f"ERROR: apidiff.add for known {classname}.{methodname}")
211215
elif "@apidiff.hide" in pre_lines:
212216
pass # continue as normal
213217
elif "@apidiff.change" in pre_lines:
@@ -443,3 +447,64 @@ def get_required_prop_names(self, classname):
443447
def get_required_method_names(self, classname):
444448
_, methods = self.classes[classname]
445449
return list(name for name in methods.keys() if methods[name][1])
450+
451+
452+
class StructValidationChecker(Patcher):
453+
"""Checks that all structs are vaildated in the methods that have incoming structs."""
454+
455+
def apply(self, code):
456+
self._init(code)
457+
458+
idl = get_idl_parser()
459+
all_structs = set()
460+
ignore_structs = {"Extent3D"}
461+
462+
for classname, i1, i2 in self.iter_classes():
463+
if classname not in idl.classes:
464+
continue
465+
466+
# For each method ...
467+
for methodname, j1, j2 in self.iter_methods(i1 + 1):
468+
code = "\n".join(self.lines[j1 : j2 + 1])
469+
# Get signature and cut it up in words
470+
sig_words = code.partition("(")[2].split("):")[0]
471+
for c in "][(),\"'":
472+
sig_words = sig_words.replace(c, " ")
473+
# Collect incoming structs from signature
474+
method_structs = set()
475+
for word in sig_words.split():
476+
if word.startswith("structs."):
477+
structname = word.partition(".")[2]
478+
method_structs.update(self._get_sub_structs(idl, structname))
479+
all_structs.update(method_structs)
480+
# Collect structs being checked
481+
checked = set()
482+
for line in code.splitlines():
483+
line = line.lstrip()
484+
if line.startswith("check_struct("):
485+
name = line.split("(")[1].split(",")[0].strip('"')
486+
checked.add(name)
487+
# Test that a matching check is done
488+
unchecked = method_structs.difference(checked)
489+
unchecked = list(sorted(unchecked.difference(ignore_structs)))
490+
if (
491+
methodname.endswith("_async")
492+
and f"return self.{methodname[:-7]}" in code
493+
):
494+
pass
495+
elif unchecked:
496+
msg = f"missing check_struct in {methodname}: {unchecked}"
497+
self.insert_line(j1, f"# FIXME: {msg}")
498+
print(f"ERROR: {msg}")
499+
500+
# Test that we did find structs. In case our detection fails for
501+
# some reason, this would probably catch that.
502+
assert len(all_structs) > 10
503+
504+
def _get_sub_structs(self, idl, structname):
505+
structnames = {structname}
506+
for structfield in idl.structs[structname].values():
507+
structname2 = structfield.typename[3:] # remove "GPU"
508+
if structname2 in idl.structs:
509+
structnames.update(self._get_sub_structs(idl, structname2))
510+
return structnames

codegen/hparser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def get_h_parser(*, allow_cache=True):
13-
""" Get the global HParser object. """
13+
"""Get the global HParser object."""
1414

1515
# Singleton pattern
1616
global _parser

codegen/idlparser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
def get_idl_parser(*, allow_cache=True):
19-
""" Get the global IdlParser object. """
19+
"""Get the global IdlParser object."""
2020

2121
# Singleton pattern
2222
global _parser

codegen/rspatcher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def apply(self, code):
193193
if name not in hp.functions:
194194
msg = f"unknown C function {name}"
195195
self.insert_line(i, f"{indent}# FIXME: {msg}")
196-
print(f"Error: {msg}")
196+
print(f"ERROR: {msg}")
197197
else:
198198
detected.add(name)
199199
anno = hp.functions[name].replace(name, "f").strip(";")
@@ -302,7 +302,7 @@ def _validate_struct(self, hp, i1, i2):
302302
if struct_name not in hp.structs:
303303
msg = f"unknown C struct {struct_name}"
304304
self.insert_line(i1, f"{indent}# FIXME: {msg}")
305-
print(f"Error: {msg}")
305+
print(f"ERROR: {msg}")
306306
return
307307
else:
308308
struct = hp.structs[struct_name]
@@ -322,7 +322,7 @@ def _validate_struct(self, hp, i1, i2):
322322
if key not in struct:
323323
msg = f"unknown C struct field {struct_name}.{key}"
324324
self.insert_line(i1 + j, f"{indent}# FIXME: {msg}")
325-
print(f"Error: {msg}")
325+
print(f"ERROR: {msg}")
326326

327327
# Insert comments for unused keys
328328
more_lines = []

codegen/tests/test_codegen_z.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,15 @@ def test_that_code_is_up_to_date():
7171
print("Codegen check ok!")
7272

7373

74+
def test_that_codegen_report_has_no_errors():
75+
filename = os.path.join(lib_dir, "resources", "codegen_report.md")
76+
with open(filename, "rb") as f:
77+
text = f.read().decode()
78+
79+
# The codegen uses a prefix "ERROR:" for unacceptable things.
80+
# All caps, some function names may contain the name "error".
81+
assert "ERROR" not in text
82+
83+
7484
if __name__ == "__main__":
7585
test_that_code_is_up_to_date()

codegen/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def print(*args, **kwargs):
5555

5656

5757
class PrintToFile:
58-
""" Context manager to print to file. """
58+
"""Context manager to print to file."""
5959

6060
def __init__(self, f):
6161
if isinstance(f, str):

examples/cube_glfw.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def draw_frame():
389389

390390

391391
def simple_event_loop():
392-
""" A real simple event loop, but it keeps the CPU busy. """
392+
"""A real simple event loop, but it keeps the CPU busy."""
393393
while update_glfw_canvasses():
394394
glfw.poll_events()
395395

examples/triangle_glfw.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121

2222

2323
def simple_event_loop():
24-
""" A real simple event loop, but it keeps the CPU busy. """
24+
"""A real simple event loop, but it keeps the CPU busy."""
2525
while update_glfw_canvasses():
2626
glfw.poll_events()
2727

2828

2929
def better_event_loop(max_fps=100):
30-
""" A simple event loop that schedules draws. """
30+
"""A simple event loop that schedules draws."""
3131
td = 1 / max_fps
3232
while update_glfw_canvasses():
3333
# Determine next time to draw

tests/test_compute.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,13 @@ def compute_shader(
172172
)
173173
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)
174174

175+
# Create and run the pipeline, fail - test check_struct
176+
with raises(ValueError):
177+
compute_pipeline = device.create_compute_pipeline(
178+
layout=pipeline_layout,
179+
compute={"module": cshader, "entry_point": "main", "foo": 42},
180+
)
181+
175182
# Create and run the pipeline
176183
compute_pipeline = device.create_compute_pipeline(
177184
layout=pipeline_layout,
@@ -259,7 +266,7 @@ def compute_shader(
259266
compute_with_buffers({0: in1}, {0: c_int32 * 100}, compute_shader, n=-1)
260267

261268
with raises(TypeError): # invalid shader
262-
compute_with_buffers({0: in1}, {0: c_int32 * 100}, "not a shader")
269+
compute_with_buffers({0: in1}, {0: c_int32 * 100}, {"not", "a", "shader"})
263270

264271

265272
if __name__ == "__main__":

0 commit comments

Comments
 (0)