Skip to content

Commit c182ca9

Browse files
committed
Refactor CABI to express post-return interleaving through thunks (no change in behavior)
1 parent ec6f3ba commit c182ca9

File tree

3 files changed

+51
-55
lines changed

3 files changed

+51
-55
lines changed

design/mvp/CanonicalABI.md

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,7 +1537,7 @@ validation specifies:
15371537
* if a `post-return` is present, it has type `(func (param flatten_functype($ft).results))`
15381538

15391539
When instantiating component instance `$inst`:
1540-
* Define `$f` to be the closure `lambda args: canon_lift($opts, $inst, $callee, $ft, args)`
1540+
* Define `$f` to be the partially-bound closure `canon_lift($opts, $inst, $callee, $ft)`
15411541

15421542
Thus, `$f` captures `$opts`, `$inst`, `$callee` and `$ft` in a closure which
15431543
can be subsequently exported or passed into a child instance (via `with`). If
@@ -1555,20 +1555,18 @@ component*.
15551555

15561556
Given the above closure arguments, `canon_lift` is defined:
15571557
```python
1558-
def canon_lift(opts, inst, callee, ft, args):
1558+
def canon_lift(opts, inst, callee, ft, start_thunk, return_thunk):
15591559
export_call = ExportCall(opts, inst)
15601560
trap_if(not inst.may_enter)
15611561

1562-
flat_args = lower_values(export_call, MAX_FLAT_PARAMS, args, ft.param_types())
1562+
flat_args = lower_values(export_call, MAX_FLAT_PARAMS, start_thunk(), ft.param_types())
15631563
flat_results = call_and_trap_on_throw(callee, flat_args)
1564-
results = lift_values(export_call, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types())
1564+
return_thunk(lift_values(export_call, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types()))
15651565

1566-
def post_return():
1567-
if opts.post_return is not None:
1568-
call_and_trap_on_throw(opts.post_return, flat_results)
1569-
export_call.exit()
1566+
if opts.post_return is not None:
1567+
call_and_trap_on_throw(opts.post_return, flat_results)
15701568

1571-
return (results, post_return)
1569+
export_call.exit()
15721570

15731571
def call_and_trap_on_throw(callee, args):
15741572
try:
@@ -1581,10 +1579,11 @@ boundaries. Thus, if a component wishes to signal an error, it must use some
15811579
sort of explicit type such as `result` (whose `error` case particular language
15821580
bindings may choose to map to and from exceptions).
15831581

1584-
The contract assumed by `canon_lift` (and ensured by `canon_lower` below) is
1585-
that the caller of `canon_lift` *must* call `post_return` right after lowering
1586-
`result`. This ensures that `post_return` can be used to perform cleanup
1587-
actions after the lowering is complete.
1582+
The `start_thunk` and `return_thunk` are used to model the interleaving of
1583+
reading arguments out of the caller's stack and memory and writing results
1584+
back into the caller's stack and memory. After the results have been copied
1585+
from the callee's memory into the caller's memory, the callee's `post_return`
1586+
function is called to allow the callee to reclaim any memory.
15881587

15891588

15901589
### `canon lower`
@@ -1600,11 +1599,9 @@ where `$callee` has type `$ft`, validation specifies:
16001599
* there is no `post-return` in `$opts`
16011600

16021601
When instantiating component instance `$inst`:
1603-
* Define `$f` to be the closure: `lambda args: canon_lower($opts, $inst, $callee, $ft, args)`
1602+
* Define `$f` to be the partially-bound closure: `canon_lower($opts, $inst, $callee, $ft)`
16041603

1605-
Thus, from the perspective of Core WebAssembly, `$f` is a [function instance]
1606-
containing a `hostfunc` that closes over `$opts`, `$inst`, `$callee` and `$ft`
1607-
and, when called from Core WebAssembly code, calls `canon_lower`, which is defined as:
1604+
where `canon_lower` is defined:
16081605
```python
16091606
def canon_lower(opts, inst, callee, calling_import, ft, flat_args):
16101607
import_call = ImportCall(opts, inst)
@@ -1615,30 +1612,23 @@ def canon_lower(opts, inst, callee, calling_import, ft, flat_args):
16151612
inst.may_enter = False
16161613

16171614
flat_args = CoreValueIter(flat_args)
1618-
args = lift_values(import_call, MAX_FLAT_PARAMS, flat_args, ft.param_types())
1615+
flat_results = None
16191616

1620-
results, post_return = callee(args)
1617+
def start_thunk():
1618+
return lift_values(import_call, MAX_FLAT_PARAMS, flat_args, ft.param_types())
16211619

1622-
flat_results = lower_values(import_call, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
1620+
def return_thunk(results):
1621+
nonlocal flat_results
1622+
flat_results = lower_values(import_call, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
16231623

1624-
post_return()
1625-
import_call.exit()
1624+
callee(start_thunk, return_thunk)
16261625

16271626
if calling_import:
16281627
inst.may_enter = True
16291628

1629+
import_call.exit()
16301630
return flat_results
16311631
```
1632-
The definitions of `canon_lift` and `canon_lower` are mostly symmetric
1633-
(swapping lifting and lowering), with a few exceptions (in `flatten_functype`,
1634-
as defined above):
1635-
* The caller does not need a `post-return` function since the Core WebAssembly
1636-
caller simply regains control when `canon_lower` returns, allowing it to free
1637-
(or not) any memory passed as `flat_args`.
1638-
* When handling the too-many-flat-values case, instead of relying on `realloc`,
1639-
the caller pass in a pointer to caller-allocated memory as a final
1640-
`i32` parameter.
1641-
16421632
Since any cross-component call necessarily transits through a statically-known
16431633
`canon_lower`+`canon_lift` call pair, an AOT compiler can fuse `canon_lift` and
16441634
`canon_lower` into a single, efficient trampoline. In the future this may allow

design/mvp/canonical-abi/definitions.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,20 +1125,18 @@ def lower_values(cx, max_flat, vs, ts, out_param = None):
11251125

11261126
### `canon lift`
11271127

1128-
def canon_lift(opts, inst, callee, ft, args):
1128+
def canon_lift(opts, inst, callee, ft, start_thunk, return_thunk):
11291129
export_call = ExportCall(opts, inst)
11301130
trap_if(not inst.may_enter)
11311131

1132-
flat_args = lower_values(export_call, MAX_FLAT_PARAMS, args, ft.param_types())
1132+
flat_args = lower_values(export_call, MAX_FLAT_PARAMS, start_thunk(), ft.param_types())
11331133
flat_results = call_and_trap_on_throw(callee, flat_args)
1134-
results = lift_values(export_call, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types())
1134+
return_thunk(lift_values(export_call, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types()))
11351135

1136-
def post_return():
1137-
if opts.post_return is not None:
1138-
call_and_trap_on_throw(opts.post_return, flat_results)
1139-
export_call.exit()
1136+
if opts.post_return is not None:
1137+
call_and_trap_on_throw(opts.post_return, flat_results)
11401138

1141-
return (results, post_return)
1139+
export_call.exit()
11421140

11431141
def call_and_trap_on_throw(callee, args):
11441142
try:
@@ -1157,18 +1155,21 @@ def canon_lower(opts, inst, callee, calling_import, ft, flat_args):
11571155
inst.may_enter = False
11581156

11591157
flat_args = CoreValueIter(flat_args)
1160-
args = lift_values(import_call, MAX_FLAT_PARAMS, flat_args, ft.param_types())
1158+
flat_results = None
11611159

1162-
results, post_return = callee(args)
1160+
def start_thunk():
1161+
return lift_values(import_call, MAX_FLAT_PARAMS, flat_args, ft.param_types())
11631162

1164-
flat_results = lower_values(import_call, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
1163+
def return_thunk(results):
1164+
nonlocal flat_results
1165+
flat_results = lower_values(import_call, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
11651166

1166-
post_return()
1167-
import_call.exit()
1167+
callee(start_thunk, return_thunk)
11681168

11691169
if calling_import:
11701170
inst.may_enter = True
11711171

1172+
import_call.exit()
11721173
return flat_results
11731174

11741175
### `canon resource.new`

design/mvp/canonical-abi/run_tests.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import definitions
2+
from functools import partial
23
from definitions import *
34

45
def equal_modulo_string_encoding(s, t):
@@ -340,7 +341,7 @@ def test_roundtrip(t, v):
340341
callee_heap = Heap(1000)
341342
callee_opts = mk_opts(callee_heap.memory, 'utf8', callee_heap.realloc, lambda x: () )
342343
callee_inst = ComponentInstance()
343-
lifted_callee = lambda args: canon_lift(callee_opts, callee_inst, callee, ft, args)
344+
lifted_callee = partial(canon_lift, callee_opts, callee_inst, callee, ft)
344345

345346
caller_heap = Heap(1000)
346347
caller_opts = mk_opts(caller_heap.memory, 'utf8', caller_heap.realloc)
@@ -379,11 +380,12 @@ def dtor(x):
379380
rt2 = ResourceType(inst, dtor) # only usable in exports
380381
opts = mk_opts()
381382

382-
def host_import(args):
383+
def host_import(start_thunk, return_thunk):
384+
args = start_thunk()
383385
assert(len(args) == 2)
384386
assert(args[0] == 42)
385387
assert(args[1] == 44)
386-
return ([45], lambda:())
388+
return_thunk([45])
387389

388390
def core_wasm(args):
389391
nonlocal dtor_value
@@ -446,13 +448,16 @@ def core_wasm(args):
446448
Own(rt),
447449
Own(rt)
448450
])
449-
args = [
450-
42,
451-
43,
452-
44,
453-
13
454-
]
455-
got,post_return = canon_lift(opts, inst, core_wasm, ft, args)
451+
452+
def arg_thunk():
453+
return [ 42, 43, 44, 13 ]
454+
455+
got = None
456+
def return_thunk(results):
457+
nonlocal got
458+
got = results
459+
460+
canon_lift(opts, inst, core_wasm, ft, arg_thunk, return_thunk)
456461

457462
assert(len(got) == 3)
458463
assert(got[0] == 46)

0 commit comments

Comments
 (0)