Skip to content

Commit 2845a51

Browse files
authored
Promise / Future API with .then() (#739)
* WIP Promise / Future API with .then() * Fix leftover dict default * update all methods * The _sync classes are implemented in the base classes * Can remove some proxy methods in _api.py now * Adjust a test in codegen * Add docs * wip * wip * cleaner * tweaks * remove accidentally added * Move implementation of promise to separate module * codegen * Change chaining * Also keep working async without a loop * tweaks * more docs and tweaks * add catch method * add to classes, so we can override behaviour in backends * Improve how GPUPromise is documented * Tweak signature * Add tests * fix names * fix more names * Fix tests * cleanup after self-review * restore cube example * tweak loop in classes.py * add back incremental sleeping * codegen * fix typo * tweak example * fix typo
1 parent 9298807 commit 2845a51

File tree

14 files changed

+1252
-419
lines changed

14 files changed

+1252
-419
lines changed

codegen/apipatcher.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ def patch_base_api(code):
3838
idl = get_idl_parser()
3939

4040
# Write __all__
41+
extra_public_classes = ["GPUPromise"]
42+
all_public_classes = [*idl.classes.keys(), *extra_public_classes]
4143
part1, found_all, part2 = code.partition("\n__all__ =")
4244
if found_all:
4345
part2 = part2.split("]", 1)[-1]
4446
line = "\n__all__ = ["
45-
line += ", ".join(f'"{name}"' for name in sorted(idl.classes.keys()))
47+
line += ", ".join(f'"{name}"' for name in sorted(all_public_classes))
4648
line += "]"
4749
code = part1 + line + part2
4850

@@ -158,14 +160,22 @@ def patch_classes(self):
158160
for classname, i1, i2 in self.iter_classes():
159161
seen_classes.add(classname)
160162
self._apidiffs = set()
163+
pre_lines = "\n".join(self.lines[i1 - 3 : i1])
164+
self._apidiffs_from_lines(pre_lines, classname)
161165
if self.class_is_known(classname):
166+
if "@apidiff.add" in pre_lines:
167+
print(f"ERROR: apidiff.add for known {classname}")
168+
elif "@apidiff.hide" in pre_lines:
169+
pass # continue as normal
162170
old_line = self.lines[i1]
163171
new_line = self.get_class_def(classname)
164172
if old_line != new_line:
165173
fixme_line = "# FIXME: was " + old_line.split("class ", 1)[-1]
166174
self.replace_line(i1, f"{fixme_line}\n{new_line}")
167175
self.patch_properties(classname, i1 + 1, i2)
168176
self.patch_methods(classname, i1 + 1, i2)
177+
elif "@apidiff.add" in pre_lines:
178+
pass
169179
else:
170180
msg = f"unknown api: class {classname}"
171181
self.insert_line(i1, "# FIXME: " + msg)
@@ -422,14 +432,15 @@ def get_property_def(self, classname, propname) -> str:
422432
print(f"Error resolving type for {classname}.{propname}: {err}")
423433
prop_type = None
424434

435+
if prop_type and propname.endswith("_async"):
436+
prop_type = f"GPUPromise[{prop_type}]"
437+
425438
line = (
426439
"def "
427440
+ to_snake_case(propname)
428441
+ "(self)"
429442
+ f"{f' -> {prop_type}' if prop_type else ''}:"
430443
)
431-
if propname.endswith("_async"):
432-
line = "async " + line
433444
return " " + line
434445

435446
def get_method_def(self, classname, methodname) -> str:
@@ -439,8 +450,6 @@ def get_method_def(self, classname, methodname) -> str:
439450

440451
# Construct preamble
441452
preamble = "def " + to_snake_case(methodname) + "("
442-
if methodname.endswith("_async"):
443-
preamble = "async " + preamble
444453

445454
# Get arg names and types
446455
idl_line = functions[name_idl]
@@ -455,6 +464,8 @@ def get_method_def(self, classname, methodname) -> str:
455464
return_type = None
456465
if return_type:
457466
return_type = self.idl.resolve_type(return_type)
467+
if methodname.endswith("_async"):
468+
return_type = f"GPUPromise[{return_type}]"
458469

459470
# If one arg that is a dict, flatten dict to kwargs
460471
if len(args) == 1 and args[0].typename.endswith(
@@ -601,6 +612,8 @@ def __init__(self, base_api_code):
601612
pre_lines = "\n".join(p1.lines[j1 - 3 : j1])
602613
if "@apidiff.hide" in pre_lines:
603614
continue # method (currently) not part of our API
615+
if methodname.endswith("_sync"):
616+
continue # the base class implements _sync versions (using promise.sync_wait())
604617
body = "\n".join(p1.lines[j1 + 1 : j2 + 1])
605618
must_overload = "raise NotImplementedError()" in body
606619
methods[methodname] = p1.lines[j1], must_overload
@@ -690,7 +703,7 @@ def apply(self, code):
690703
defer_func_name = "_" + methodname
691704
defer_line_starts = (
692705
f"return self.{defer_func_name[:-7]}",
693-
f"awaitable = self.{defer_func_name[:-7]}",
706+
f"promise = self.{defer_func_name[:-7]}",
694707
)
695708
this_method_defers = any(
696709
line.strip().startswith(defer_line_starts)
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
"""Test some aspects of the generated code."""
22

33
from codegen.files import read_file
4+
from codegen.utils import format_code
45

56

67
def test_async_methods_and_props():
7-
# Test that only and all async methods are suffixed with '_async'
8+
# Test that async methods return a promise
89

910
for fname in ["_classes.py", "backends/wgpu_native/_api.py"]:
10-
code = read_file(fname)
11+
code = format_code(read_file(fname), singleline=True)
1112
for line in code.splitlines():
1213
line = line.strip()
1314
if line.startswith("def "):
14-
assert not line.endswith("_async"), line
15-
elif line.startswith("async def "):
16-
name = line.split("def", 1)[1].split("(")[0].strip()
17-
assert name.endswith("_async"), line
15+
res_type = line.split("->")[-1].strip()
16+
if "_async(" in line:
17+
assert res_type.startswith("GPUPromise")
18+
else:
19+
assert "GPUPromise" not in line

docs/_templates/wgpu_class_layout.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22

33
.. currentmodule:: {{ module }}
44

5+
{% if objname == "GPUPromise" %}
6+
57
.. autoclass:: {{ objname }}
68
:members:
9+
:inherited-members:
710
:show-inheritance:
11+
12+
{% else %}
13+
14+
.. autoclass:: {{ objname }}
15+
:members:
16+
:show-inheritance:
17+
18+
{% endif %}

docs/wgpu.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ come in two flafours:
6565
not part of the WebGPU spec, and as a consequence, code that uses this method
6666
is less portable (to e.g. pyodide/pyscript).
6767

68+
The async methods return a :class:`GPUPromise`, which resolves to the actual result. You can wait for it to resolve in three ways:
69+
70+
* In async code, use ``await promise``.
71+
* In sync code, use ``promise.then(callback)`` to register a callback that is executed when the promise resolves.
72+
* In sync code, you can use ``promise.sync_wait()``. This is similar to the ``_sync()`` flavour mentioned above (it makes your code less portable).
73+
6874

6975
Canvas API
7076
----------
@@ -241,6 +247,7 @@ List of GPU classes
241247
~GPUPipelineBase
242248
~GPUPipelineError
243249
~GPUPipelineLayout
250+
~GPUPromise
244251
~GPUQuerySet
245252
~GPUQueue
246253
~GPURenderBundle

examples/cube.py

Lines changed: 39 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def setup_drawing_sync(
3434

3535
adapter = wgpu.gpu.request_adapter_sync(power_preference=power_preference)
3636
device = adapter.request_device_sync(
37-
required_limits=limits, label="Cube Example device"
37+
label="Cube Example device",
38+
required_limits=limits,
3839
)
3940

4041
pipeline_layout, uniform_buffer, bind_group = create_pipeline_layout(device)
@@ -43,7 +44,7 @@ def setup_drawing_sync(
4344
render_pipeline = device.create_render_pipeline(**pipeline_kwargs)
4445

4546
return get_draw_function(
46-
canvas, device, render_pipeline, uniform_buffer, bind_group, asynchronous=False
47+
canvas, device, render_pipeline, uniform_buffer, bind_group
4748
)
4849

4950

@@ -53,10 +54,10 @@ async def setup_drawing_async(canvas, limits=None):
5354
The given canvas must implement WgpuCanvasInterface, but nothing more.
5455
Returns the draw function.
5556
"""
56-
5757
adapter = await wgpu.gpu.request_adapter_async(power_preference="high-performance")
58+
5859
device = await adapter.request_device_async(
59-
required_limits=limits, label="Cube Example async device"
60+
label="Cube Example async device", required_limits=limits
6061
)
6162

6263
pipeline_layout, uniform_buffer, bind_group = create_pipeline_layout(device)
@@ -65,7 +66,19 @@ async def setup_drawing_async(canvas, limits=None):
6566
render_pipeline = await device.create_render_pipeline_async(**pipeline_kwargs)
6667

6768
return get_draw_function(
68-
canvas, device, render_pipeline, uniform_buffer, bind_group, asynchronous=True
69+
canvas, device, render_pipeline, uniform_buffer, bind_group
70+
)
71+
72+
73+
def get_drawing_func(canvas, device):
74+
pipeline_layout, uniform_buffer, bind_group = create_pipeline_layout(device)
75+
pipeline_kwargs = get_render_pipeline_kwargs(canvas, device, pipeline_layout)
76+
77+
render_pipeline = device.create_render_pipeline(**pipeline_kwargs)
78+
# render_pipeline = device.create_render_pipeline(**pipeline_kwargs)
79+
80+
return get_draw_function(
81+
canvas, device, render_pipeline, uniform_buffer, bind_group
6982
)
7083

7184

@@ -242,8 +255,6 @@ def get_draw_function(
242255
render_pipeline: wgpu.GPURenderPipeline,
243256
uniform_buffer: wgpu.GPUBuffer,
244257
bind_group: wgpu.GPUBindGroup,
245-
*,
246-
asynchronous: bool,
247258
):
248259
# Create vertex buffer, and upload data
249260
vertex_buffer = device.create_buffer_with_data(
@@ -288,48 +299,8 @@ def update_transform():
288299
)
289300
uniform_data["transform"] = rot2 @ rot1 @ ortho
290301

291-
def upload_uniform_buffer_sync():
292-
if True:
293-
tmp_buffer = uniform_buffer.copy_buffer
294-
tmp_buffer.map_sync(wgpu.MapMode.WRITE)
295-
tmp_buffer.write_mapped(uniform_data)
296-
tmp_buffer.unmap()
297-
else:
298-
tmp_buffer = device.create_buffer_with_data(
299-
data=uniform_data, usage=wgpu.BufferUsage.COPY_SRC
300-
)
301-
command_encoder = device.create_command_encoder(
302-
label="Cube Example uniform buffer upload command encoder"
303-
)
304-
command_encoder.copy_buffer_to_buffer(
305-
tmp_buffer, 0, uniform_buffer, 0, uniform_data.nbytes
306-
)
307-
device.queue.submit(
308-
[
309-
command_encoder.finish(
310-
label="Cube Example uniform buffer upload command buffer"
311-
)
312-
]
313-
)
314-
315-
async def upload_uniform_buffer_async():
316-
tmp_buffer = uniform_buffer.copy_buffer
317-
await tmp_buffer.map_async(wgpu.MapMode.WRITE)
318-
tmp_buffer.write_mapped(uniform_data)
319-
tmp_buffer.unmap()
320-
command_encoder = device.create_command_encoder(
321-
label="Cube Example uniform buffer upload async command encoder"
322-
)
323-
command_encoder.copy_buffer_to_buffer(
324-
tmp_buffer, 0, uniform_buffer, 0, uniform_data.nbytes
325-
)
326-
device.queue.submit(
327-
[
328-
command_encoder.finish(
329-
label="Cube Example uniform buffer upload async command buffer"
330-
)
331-
]
332-
)
302+
def upload_uniform_buffer():
303+
device.queue.write_buffer(uniform_buffer, 0, uniform_data)
333304

334305
def draw_frame():
335306
current_texture_view: wgpu.GPUTextureView = (
@@ -367,20 +338,12 @@ def draw_frame():
367338
[command_encoder.finish(label="Cube Example render pass command buffer")]
368339
)
369340

370-
def draw_frame_sync():
341+
def draw_func():
371342
update_transform()
372-
upload_uniform_buffer_sync()
343+
upload_uniform_buffer()
373344
draw_frame()
374345

375-
async def draw_frame_async():
376-
update_transform()
377-
await upload_uniform_buffer_async()
378-
draw_frame()
379-
380-
if asynchronous:
381-
return draw_frame_async
382-
else:
383-
return draw_frame_sync
346+
return draw_func
384347

385348

386349
# %% WGSL
@@ -509,6 +472,7 @@ async def draw_frame_async():
509472
for a in wgpu.gpu.enumerate_adapters_sync():
510473
print(a.summary)
511474

475+
512476
if __name__ == "__main__":
513477
canvas = RenderCanvas(
514478
size=(640, 480),
@@ -517,6 +481,19 @@ async def draw_frame_async():
517481
max_fps=60,
518482
vsync=True,
519483
)
520-
draw_frame = setup_drawing_sync(canvas)
521-
canvas.request_draw(draw_frame)
484+
485+
# Pick one
486+
487+
if True:
488+
# Async
489+
@loop.add_task
490+
async def init():
491+
draw_frame = await setup_drawing_async(canvas)
492+
canvas.request_draw(draw_frame)
493+
else:
494+
# Sync
495+
draw_frame = setup_drawing_sync(canvas)
496+
canvas.request_draw(draw_frame)
497+
498+
# loop.add_task(poller)
522499
loop.run()

tests/test_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ def test_basic_api():
2222

2323
code1 = wgpu.GPU.request_adapter_sync.__code__
2424
code2 = wgpu.GPU.request_adapter_async.__code__
25+
varnames1 = set(code1.co_varnames) - {"gpu", "promise", "loop"}
26+
varnames2 = set(code2.co_varnames) - {"gpu", "promise", "loop"}
2527
# nargs1 = code1.co_argcount + code1.co_kwonlyargcount
26-
assert code1.co_varnames == code2.co_varnames
28+
assert varnames1 == varnames2
2729

2830
assert repr(wgpu.classes.GPU()).startswith(
2931
"<wgpu.GPU "
@@ -98,7 +100,7 @@ def test_enums_and_flags_and_structs():
98100

99101
def test_base_wgpu_api():
100102
# Fake a device and an adapter
101-
adapter = wgpu.GPUAdapter(None, set(), {}, wgpu.GPUAdapterInfo({}))
103+
adapter = wgpu.GPUAdapter(None, set(), {}, wgpu.GPUAdapterInfo({}), None)
102104
queue = wgpu.GPUQueue("", None, None)
103105
device = wgpu.GPUDevice("device08", -1, adapter, {42, 43}, {}, queue)
104106

0 commit comments

Comments
 (0)