Skip to content

Commit cd0a99e

Browse files
committed
feat: implement tail call optimization
1 parent 8c82b68 commit cd0a99e

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

src/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ set(passes_SOURCES
114114
ReorderGlobals.cpp
115115
ReorderLocals.cpp
116116
ReReloop.cpp
117+
TailCall.cpp
117118
TrapMode.cpp
118119
TypeGeneralizing.cpp
119120
TypeRefining.cpp

src/passes/TailCall.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
2+
#include "pass.h"
3+
#include "wasm-traversal.h"
4+
#include "wasm.h"
5+
#include <vector>
6+
7+
namespace wasm {
8+
9+
namespace {
10+
11+
struct Finder : PostWalker<Finder> {
12+
std::vector<Call*> tailCalls;
13+
std::vector<CallIndirect*> tailCallIndirects;
14+
void visitFunction(Function* curr) { checkTailCall(curr->body); }
15+
void visitReturn(Return* curr) { checkTailCall(curr->value); }
16+
17+
private:
18+
void checkTailCall(Expression* expr) {
19+
if (expr == nullptr) {
20+
return;
21+
}
22+
if (auto* call = expr->dynCast<Call>()) {
23+
if (!call->isReturn && call->type == getFunction()->getResults()) {
24+
tailCalls.push_back(call);
25+
}
26+
return;
27+
}
28+
if (auto* call = expr->dynCast<CallIndirect>()) {
29+
if (!call->isReturn && call->type == getFunction()->getResults()) {
30+
tailCallIndirects.push_back(call);
31+
}
32+
return;
33+
}
34+
if (auto* block = expr->dynCast<Block>()) {
35+
return checkTailCall(block->list);
36+
}
37+
if (auto* ifElse = expr->dynCast<If>()) {
38+
checkTailCall(ifElse->ifTrue);
39+
checkTailCall(ifElse->ifFalse);
40+
return;
41+
}
42+
}
43+
void checkTailCall(ExpressionList const& exprs) {
44+
if (exprs.empty()) {
45+
return;
46+
}
47+
checkTailCall(exprs.back());
48+
return;
49+
}
50+
};
51+
52+
} // namespace
53+
54+
struct TailCallOptimizer : public Pass {
55+
bool isFunctionParallel() override { return true; }
56+
std::unique_ptr<Pass> create() override {
57+
return std::make_unique<TailCallOptimizer>();
58+
}
59+
void runOnFunction(Module* module, Function* function) override {
60+
if (!module->features.hasTailCall()) {
61+
return;
62+
}
63+
Finder finder{};
64+
finder.walkFunctionInModule(function, module);
65+
for (Call* call : finder.tailCalls) {
66+
if (!call->isReturn) {
67+
call->isReturn = true;
68+
call->finalize();
69+
}
70+
}
71+
for (CallIndirect* call : finder.tailCallIndirects) {
72+
if (!call->isReturn) {
73+
call->isReturn = true;
74+
call->finalize();
75+
}
76+
}
77+
}
78+
};
79+
80+
Pass* createTailCallPass() { return new TailCallOptimizer(); }
81+
82+
} // namespace wasm

src/passes/pass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ void PassRegistry::registerPasses() {
552552
registerPass("strip-target-features",
553553
"strip the wasm target features section",
554554
createStripTargetFeaturesPass);
555+
registerPass(
556+
"tail-call", "transform call to return call", createTailCallPass);
555557
registerPass("translate-to-new-eh",
556558
"deprecated; same as translate-to-exnref",
557559
createTranslateToExnrefPass);

src/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ Pass* createStripEHPass();
176176
Pass* createStubUnsupportedJSOpsPass();
177177
Pass* createSSAifyPass();
178178
Pass* createSSAifyNoMergePass();
179+
Pass* createTailCallPass();
179180
Pass* createTable64LoweringPass();
180181
Pass* createTranslateToExnrefPass();
181182
Pass* createTrapModeClamp();

0 commit comments

Comments
 (0)