Skip to content

Commit 1d6320a

Browse files
authored
[Stack Switching] Update continuation types based on resume tag & target block (#7786)
The resume tag defines how the continuation should be called, perhaps with different parameters than before. To handle that, update its type and also properly handle function parameters (if we are resuming with a different type, we should not error - the function params are already saved as local state, and we just need to reload them). The actual type also depends on the target block. This allows the next part of the spec test to run.
1 parent b28ab39 commit 1d6320a

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

src/wasm-interpreter.h

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3458,14 +3458,20 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
34583458
parent.scope = this;
34593459
parent.callDepth++;
34603460
parent.functionStack.push_back(function->name);
3461+
locals.resize(function->getNumLocals());
3462+
3463+
if (parent.resuming) {
3464+
// Nothing more to do here: we are resuming execution, so there is old
3465+
// locals state that will be restored.
3466+
return;
3467+
}
34613468

34623469
if (function->getParams().size() != arguments.size()) {
34633470
std::cerr << "Function `" << function->name << "` expects "
34643471
<< function->getParams().size() << " parameters, got "
34653472
<< arguments.size() << " arguments." << std::endl;
34663473
WASM_UNREACHABLE("invalid param count");
34673474
}
3468-
locals.resize(function->getNumLocals());
34693475
Type params = function->getParams();
34703476
for (size_t i = 0; i < function->getNumLocals(); i++) {
34713477
if (i < arguments.size()) {
@@ -4697,6 +4703,10 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
46974703
assert(!old || old->executed);
46984704
auto new_ = std::make_shared<ContData>(old ? old->func : Name(),
46994705
old ? old->type : HeapType::none);
4706+
// Note we cannot update the type yet, so it will be wrong in debug
4707+
// logging. To update it, we must find the block that receives this value,
4708+
// which means we cannot do it here (we don't even know what that block is).
4709+
47004710
// Switch to the new continuation, so that as we unwind, we will save the
47014711
// information we need to resume it later in the proper place.
47024712
self()->currContinuation = new_;
@@ -4758,8 +4768,30 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
47584768
// Switch the flow from suspending to branching.
47594769
ret.suspendTag = Name();
47604770
ret.breakTo = curr->handlerBlocks[i];
4771+
// We can now update the continuation type, which was wrong until now
4772+
// (see comment in visitSuspend). The type is taken from the block we
4773+
// branch to (which we find in a quite inefficient manner).
4774+
struct BlockFinder : public PostWalker<BlockFinder> {
4775+
Name target;
4776+
Type type = Type::none;
4777+
void visitBlock(Block* curr) {
4778+
if (curr->name == target) {
4779+
type = curr->type;
4780+
}
4781+
}
4782+
} finder;
4783+
finder.target = ret.breakTo;
4784+
// We must be in a function scope.
4785+
assert(self()->scope->function);
4786+
finder.walk(self()->scope->function->body);
4787+
// We must have found the type, and it must be valid.
4788+
assert(finder.type.isConcrete());
4789+
assert(finder.type.size() >= 1);
4790+
// The continuation is the final value/type there.
4791+
auto cont = self()->currContinuation;
4792+
cont->type = finder.type[finder.type.size() - 1].getHeapType();
47614793
// Add the continuation as the final value being sent.
4762-
ret.values.push_back(Literal(self()->currContinuation));
4794+
ret.values.push_back(Literal(cont));
47634795
// We are not longer processing that continuation.
47644796
self()->currContinuation.reset();
47654797
return ret;

test/spec/cont.wast

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,58 @@
279279

280280
(assert_return (invoke "run") (i32.const 19))
281281

282+
;; Simple generator example
283+
284+
(module $generator
285+
(type $gen (func (param i64)))
286+
(type $geny (func (param i32)))
287+
(type $cont0 (cont $gen))
288+
(type $cont (cont $geny))
289+
290+
(tag $yield (param i64) (result i32))
291+
292+
;; Hook for logging purposes
293+
(global $hook (export "hook") (mut (ref $gen)) (ref.func $dummy))
294+
(func $dummy (param i64))
295+
296+
(func $gen (export "start") (param $i i64)
297+
(loop $l
298+
(br_if 1 (suspend $yield (local.get $i)))
299+
(call_ref $gen (local.get $i) (global.get $hook))
300+
(local.set $i (i64.add (local.get $i) (i64.const 1)))
301+
(br $l)
302+
)
303+
)
304+
305+
(elem declare func $gen)
306+
307+
(func (export "sum") (param $i i64) (param $j i64) (result i64)
308+
(local $sum i64)
309+
(local $n i64)
310+
(local $k (ref null $cont))
311+
(local.get $i)
312+
(cont.new $cont0 (ref.func $gen))
313+
(block $on_first_yield (param i64 (ref $cont0)) (result i64 (ref $cont))
314+
(resume $cont0 (on $yield $on_first_yield))
315+
(unreachable)
316+
)
317+
(loop $on_yield (param i64) (param (ref $cont))
318+
(local.set $k)
319+
(local.set $n)
320+
(local.set $sum (i64.add (local.get $sum) (local.get $n)))
321+
(i64.eq (local.get $n) (local.get $j))
322+
(local.get $k)
323+
(resume $cont (on $yield $on_yield))
324+
)
325+
(return (local.get $sum))
326+
)
327+
)
328+
329+
(register "generator")
330+
331+
(assert_return (invoke "sum" (i64.const 0) (i64.const 0)) (i64.const 0))
332+
(assert_return (invoke "sum" (i64.const 2) (i64.const 2)) (i64.const 2))
333+
(assert_return (invoke "sum" (i64.const 0) (i64.const 3)) (i64.const 6))
334+
(assert_return (invoke "sum" (i64.const 1) (i64.const 10)) (i64.const 55))
335+
(assert_return (invoke "sum" (i64.const 100) (i64.const 2000)) (i64.const 1_996_050))
336+

0 commit comments

Comments
 (0)