Skip to content

Commit 2a01572

Browse files
committed
implement for w/o return function
1 parent d48eab8 commit 2a01572

File tree

2 files changed

+174
-17
lines changed

2 files changed

+174
-17
lines changed

src/passes/TailCall.cpp

Lines changed: 129 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,125 @@
11

2+
#include "cfg/cfg-traversal.h"
23
#include "ir/properties.h"
34
#include "ir/utils.h"
45
#include "pass.h"
56
#include "wasm-traversal.h"
67
#include "wasm.h"
8+
#include <algorithm>
9+
#include <cassert>
10+
#include <iostream>
11+
#include <optional>
712
#include <stack>
813
#include <vector>
914

1015
namespace wasm {
1116

1217
namespace {
1318

14-
struct Finder : TryDepthWalker<Finder> {
15-
explicit Finder(const PassOptions& passOptions)
16-
: TryDepthWalker<Finder>(), passOptions(passOptions) {}
17-
const PassOptions& passOptions;
19+
struct Info {
20+
bool isStartWithReturn = false;
21+
bool isInsideTryBlock = false;
22+
Expression* lastExpr = nullptr;
23+
};
24+
25+
struct NonReturnFinder
26+
: public CFGWalker<NonReturnFinder,
27+
UnifiedExpressionVisitor<NonReturnFinder>,
28+
Info> {
29+
using S =
30+
CFGWalker<NonReturnFinder, UnifiedExpressionVisitor<NonReturnFinder>, Info>;
31+
1832
std::vector<Call*> tailCalls;
1933
std::vector<CallIndirect*> tailCallIndirects;
20-
void visitFunction(Function* curr) {
21-
if (passOptions.shrinkLevel > 0 && passOptions.optimizeLevel == 0) {
22-
// When we more force on the binary size, add return_call will increase
23-
// the code size.
34+
35+
void visitExpression(Expression* curr) {
36+
if (currBasicBlock == nullptr) {
2437
return;
2538
}
26-
checkTailCall(curr->body);
39+
if (curr->is<Block>() || curr->is<If>() || curr->is<Loop>()) {
40+
// skip all control flow instructions
41+
return;
42+
}
43+
44+
Expression* const lastExpr = currBasicBlock->contents.lastExpr;
45+
currBasicBlock->contents.lastExpr = curr;
46+
47+
if (!tryStack.empty()) {
48+
// skip all try stack
49+
currBasicBlock->contents.isInsideTryBlock = true;
50+
}
51+
if (curr->is<Return>()) {
52+
if (lastExpr == nullptr) {
53+
currBasicBlock->contents.isStartWithReturn = true;
54+
} else {
55+
pushPotentialTailCall(lastExpr);
56+
}
57+
}
58+
}
59+
60+
void pushPotentialTailCall(Expression* curr) {
61+
if (curr) {
62+
if (curr->is<Call>()) {
63+
tailCalls.push_back(curr->cast<Call>());
64+
} else if (curr->is<CallIndirect>()) {
65+
tailCallIndirects.push_back(curr->cast<CallIndirect>());
66+
}
67+
}
68+
}
69+
70+
void doWalkFunction(Function* func) {
71+
S::doWalkFunction(func);
72+
if (hasSyntheticExit && exit != nullptr) {
73+
exit->contents.isStartWithReturn = true;
74+
}
75+
if (exit != nullptr) {
76+
assert(tryStack.empty());
77+
pushPotentialTailCall(exit->contents.lastExpr);
78+
}
79+
// propagate start with return flag
80+
bool hasUpdated = true;
81+
while (hasUpdated) {
82+
hasUpdated = false;
83+
for (std::unique_ptr<BasicBlock> const& bb : basicBlocks) {
84+
if (bb->contents.isStartWithReturn) {
85+
continue;
86+
}
87+
if (bb->contents.lastExpr == nullptr) {
88+
const bool followBasicBlockStartWithReturn =
89+
std::all_of(bb->out.begin(), bb->out.end(), [](BasicBlock* b) {
90+
return b->contents.isStartWithReturn;
91+
});
92+
if (followBasicBlockStartWithReturn) {
93+
bb->contents.isStartWithReturn = true;
94+
hasUpdated = true;
95+
}
96+
}
97+
}
98+
}
99+
for (std::unique_ptr<BasicBlock> const& bb : basicBlocks) {
100+
Expression* const lastExpr = bb->contents.lastExpr;
101+
if (lastExpr == nullptr) {
102+
continue;
103+
}
104+
const bool followBasicBlockStartWithReturn =
105+
std::all_of(bb->out.begin(), bb->out.end(), [](BasicBlock* b) {
106+
return b->contents.isStartWithReturn;
107+
});
108+
if (!followBasicBlockStartWithReturn) {
109+
continue;
110+
}
111+
pushPotentialTailCall(lastExpr);
112+
}
27113
}
114+
};
115+
116+
struct ReturnFinder : TryDepthWalker<ReturnFinder> {
117+
explicit ReturnFinder(const PassOptions& passOptions)
118+
: TryDepthWalker<ReturnFinder>(), passOptions(passOptions) {}
119+
const PassOptions& passOptions;
120+
std::vector<Call*> tailCalls;
121+
std::vector<CallIndirect*> tailCallIndirects;
122+
void visitFunction(Function* curr) { checkTailCall(curr->body); }
28123
void visitReturn(Return* curr) {
29124
if (tryDepth > 0) {
30125
// (return (call ...)) is not equal to (return_call ...) in try block
@@ -77,22 +172,39 @@ struct TailCallOptimizer : public Pass {
77172
std::unique_ptr<Pass> create() override {
78173
return std::make_unique<TailCallOptimizer>();
79174
}
80-
void runOnFunction(Module* module, Function* function) override {
81-
if (!module->features.hasTailCall()) {
82-
return;
83-
}
84-
Finder finder{getPassOptions()};
85-
finder.walkFunctionInModule(function, module);
86-
for (Call* call : finder.tailCalls) {
175+
176+
static void modify(std::vector<Call*> const& tailCalls,
177+
std::vector<CallIndirect*> const& tailCallIndirects) {
178+
for (Call* call : tailCalls) {
87179
if (!call->isReturn) {
88180
call->isReturn = true;
89181
}
90182
}
91-
for (CallIndirect* call : finder.tailCallIndirects) {
183+
for (CallIndirect* call : tailCallIndirects) {
92184
if (!call->isReturn) {
93185
call->isReturn = true;
94186
}
95187
}
188+
}
189+
void runOnFunction(Module* module, Function* function) override {
190+
if (!module->features.hasTailCall()) {
191+
return;
192+
}
193+
if (getPassOptions().shrinkLevel > 0 &&
194+
getPassOptions().optimizeLevel == 0) {
195+
// When we more force on the binary size, add return_call will increase
196+
// the code size.
197+
return;
198+
}
199+
if (function->getResults().size() == 0) {
200+
NonReturnFinder finder{};
201+
finder.walkFunctionInModule(function, module);
202+
modify(finder.tailCalls, finder.tailCallIndirects);
203+
} else {
204+
ReturnFinder finder{getPassOptions()};
205+
finder.walkFunctionInModule(function, module);
206+
modify(finder.tailCalls, finder.tailCallIndirects);
207+
}
96208
ReFinalize{}.walkFunctionInModule(function, module);
97209
}
98210
};

test/lit/tail-call-optimization.wast

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
;; CHECK: (type $1 (func (param i32) (result i32)))
1111

12+
;; CHECK: (type $2 (func (param i32)))
13+
14+
;; CHECK: (type $3 (func))
15+
1216
;; CHECK: (func $f (result i32)
1317
;; CHECK-NEXT: (i32.const 0)
1418
;; CHECK-NEXT: )
@@ -128,6 +132,47 @@
128132
br_if 0
129133
end
130134
)
135+
136+
;; CHECK: (func $g (param $0 i32)
137+
;; CHECK-NEXT: )
138+
(func $g (param i32))
139+
;; CHECK: (func $return_without_value
140+
;; CHECK-NEXT: (return_call $g
141+
;; CHECK-NEXT: (i32.const 1)
142+
;; CHECK-NEXT: )
143+
;; CHECK-NEXT: )
144+
(func $return_without_value
145+
i32.const 1
146+
call $g
147+
)
148+
;; CHECK: (func $return_without_value_through_block (param $0 i32)
149+
;; CHECK-NEXT: (return_call $g
150+
;; CHECK-NEXT: (local.get $0)
151+
;; CHECK-NEXT: )
152+
;; CHECK-NEXT: )
153+
(func $return_without_value_through_block (param $0 i32)
154+
block
155+
local.get 0
156+
call $g
157+
end
158+
)
159+
;; CHECK: (func $return_without_value_through_if (param $0 i32)
160+
;; CHECK-NEXT: (if
161+
;; CHECK-NEXT: (local.get $0)
162+
;; CHECK-NEXT: (then
163+
;; CHECK-NEXT: (return_call $g
164+
;; CHECK-NEXT: (local.get $0)
165+
;; CHECK-NEXT: )
166+
;; CHECK-NEXT: )
167+
;; CHECK-NEXT: )
168+
;; CHECK-NEXT: )
169+
(func $return_without_value_through_if (param $0 i32)
170+
local.get 0
171+
if
172+
local.get 0
173+
call $g
174+
end
175+
)
131176
)
132177

133178
(module $NYI

0 commit comments

Comments
 (0)