|
1 | 1 |
|
| 2 | +#include "cfg/cfg-traversal.h" |
2 | 3 | #include "ir/properties.h"
|
3 | 4 | #include "ir/utils.h"
|
4 | 5 | #include "pass.h"
|
5 | 6 | #include "wasm-traversal.h"
|
6 | 7 | #include "wasm.h"
|
| 8 | +#include <algorithm> |
| 9 | +#include <cassert> |
| 10 | +#include <iostream> |
| 11 | +#include <optional> |
7 | 12 | #include <stack>
|
8 | 13 | #include <vector>
|
9 | 14 |
|
10 | 15 | namespace wasm {
|
11 | 16 |
|
12 | 17 | namespace {
|
13 | 18 |
|
14 |
| -struct Finder : TryDepthWalker<Finder> { |
15 |
| - explicit Finder(const PassOptions& passOptions) |
16 |
| - : TryDepthWalker<Finder>(), passOptions(passOptions) {} |
17 |
| - const PassOptions& passOptions; |
| 19 | +struct Info { |
| 20 | + bool isStartWithReturn = false; |
| 21 | + bool isInsideTryBlock = false; |
| 22 | + Expression* lastExpr = nullptr; |
| 23 | +}; |
| 24 | + |
| 25 | +struct NonReturnFinder |
| 26 | + : public CFGWalker<NonReturnFinder, |
| 27 | + UnifiedExpressionVisitor<NonReturnFinder>, |
| 28 | + Info> { |
| 29 | + using S = |
| 30 | + CFGWalker<NonReturnFinder, UnifiedExpressionVisitor<NonReturnFinder>, Info>; |
| 31 | + |
18 | 32 | std::vector<Call*> tailCalls;
|
19 | 33 | std::vector<CallIndirect*> tailCallIndirects;
|
20 |
| - void visitFunction(Function* curr) { |
21 |
| - if (passOptions.shrinkLevel > 0 && passOptions.optimizeLevel == 0) { |
22 |
| - // When we more force on the binary size, add return_call will increase |
23 |
| - // the code size. |
| 34 | + |
| 35 | + void visitExpression(Expression* curr) { |
| 36 | + if (currBasicBlock == nullptr) { |
24 | 37 | return;
|
25 | 38 | }
|
26 |
| - checkTailCall(curr->body); |
| 39 | + if (curr->is<Block>() || curr->is<If>() || curr->is<Loop>()) { |
| 40 | + // skip all control flow instructions |
| 41 | + return; |
| 42 | + } |
| 43 | + |
| 44 | + Expression* const lastExpr = currBasicBlock->contents.lastExpr; |
| 45 | + currBasicBlock->contents.lastExpr = curr; |
| 46 | + |
| 47 | + if (!tryStack.empty()) { |
| 48 | + // skip all try stack |
| 49 | + currBasicBlock->contents.isInsideTryBlock = true; |
| 50 | + } |
| 51 | + if (curr->is<Return>()) { |
| 52 | + if (lastExpr == nullptr) { |
| 53 | + currBasicBlock->contents.isStartWithReturn = true; |
| 54 | + } else { |
| 55 | + pushPotentialTailCall(lastExpr); |
| 56 | + } |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + void pushPotentialTailCall(Expression* curr) { |
| 61 | + if (curr) { |
| 62 | + if (curr->is<Call>()) { |
| 63 | + tailCalls.push_back(curr->cast<Call>()); |
| 64 | + } else if (curr->is<CallIndirect>()) { |
| 65 | + tailCallIndirects.push_back(curr->cast<CallIndirect>()); |
| 66 | + } |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + void doWalkFunction(Function* func) { |
| 71 | + S::doWalkFunction(func); |
| 72 | + if (hasSyntheticExit && exit != nullptr) { |
| 73 | + exit->contents.isStartWithReturn = true; |
| 74 | + } |
| 75 | + if (exit != nullptr) { |
| 76 | + assert(tryStack.empty()); |
| 77 | + pushPotentialTailCall(exit->contents.lastExpr); |
| 78 | + } |
| 79 | + // propagate start with return flag |
| 80 | + bool hasUpdated = true; |
| 81 | + while (hasUpdated) { |
| 82 | + hasUpdated = false; |
| 83 | + for (std::unique_ptr<BasicBlock> const& bb : basicBlocks) { |
| 84 | + if (bb->contents.isStartWithReturn) { |
| 85 | + continue; |
| 86 | + } |
| 87 | + if (bb->contents.lastExpr == nullptr) { |
| 88 | + const bool followBasicBlockStartWithReturn = |
| 89 | + std::all_of(bb->out.begin(), bb->out.end(), [](BasicBlock* b) { |
| 90 | + return b->contents.isStartWithReturn; |
| 91 | + }); |
| 92 | + if (followBasicBlockStartWithReturn) { |
| 93 | + bb->contents.isStartWithReturn = true; |
| 94 | + hasUpdated = true; |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | + } |
| 99 | + for (std::unique_ptr<BasicBlock> const& bb : basicBlocks) { |
| 100 | + Expression* const lastExpr = bb->contents.lastExpr; |
| 101 | + if (lastExpr == nullptr) { |
| 102 | + continue; |
| 103 | + } |
| 104 | + const bool followBasicBlockStartWithReturn = |
| 105 | + std::all_of(bb->out.begin(), bb->out.end(), [](BasicBlock* b) { |
| 106 | + return b->contents.isStartWithReturn; |
| 107 | + }); |
| 108 | + if (!followBasicBlockStartWithReturn) { |
| 109 | + continue; |
| 110 | + } |
| 111 | + pushPotentialTailCall(lastExpr); |
| 112 | + } |
27 | 113 | }
|
| 114 | +}; |
| 115 | + |
| 116 | +struct ReturnFinder : TryDepthWalker<ReturnFinder> { |
| 117 | + explicit ReturnFinder(const PassOptions& passOptions) |
| 118 | + : TryDepthWalker<ReturnFinder>(), passOptions(passOptions) {} |
| 119 | + const PassOptions& passOptions; |
| 120 | + std::vector<Call*> tailCalls; |
| 121 | + std::vector<CallIndirect*> tailCallIndirects; |
| 122 | + void visitFunction(Function* curr) { checkTailCall(curr->body); } |
28 | 123 | void visitReturn(Return* curr) {
|
29 | 124 | if (tryDepth > 0) {
|
30 | 125 | // (return (call ...)) is not equal to (return_call ...) in try block
|
@@ -77,22 +172,39 @@ struct TailCallOptimizer : public Pass {
|
77 | 172 | std::unique_ptr<Pass> create() override {
|
78 | 173 | return std::make_unique<TailCallOptimizer>();
|
79 | 174 | }
|
80 |
| - void runOnFunction(Module* module, Function* function) override { |
81 |
| - if (!module->features.hasTailCall()) { |
82 |
| - return; |
83 |
| - } |
84 |
| - Finder finder{getPassOptions()}; |
85 |
| - finder.walkFunctionInModule(function, module); |
86 |
| - for (Call* call : finder.tailCalls) { |
| 175 | + |
| 176 | + static void modify(std::vector<Call*> const& tailCalls, |
| 177 | + std::vector<CallIndirect*> const& tailCallIndirects) { |
| 178 | + for (Call* call : tailCalls) { |
87 | 179 | if (!call->isReturn) {
|
88 | 180 | call->isReturn = true;
|
89 | 181 | }
|
90 | 182 | }
|
91 |
| - for (CallIndirect* call : finder.tailCallIndirects) { |
| 183 | + for (CallIndirect* call : tailCallIndirects) { |
92 | 184 | if (!call->isReturn) {
|
93 | 185 | call->isReturn = true;
|
94 | 186 | }
|
95 | 187 | }
|
| 188 | + } |
| 189 | + void runOnFunction(Module* module, Function* function) override { |
| 190 | + if (!module->features.hasTailCall()) { |
| 191 | + return; |
| 192 | + } |
| 193 | + if (getPassOptions().shrinkLevel > 0 && |
| 194 | + getPassOptions().optimizeLevel == 0) { |
| 195 | + // When we more force on the binary size, add return_call will increase |
| 196 | + // the code size. |
| 197 | + return; |
| 198 | + } |
| 199 | + if (function->getResults().size() == 0) { |
| 200 | + NonReturnFinder finder{}; |
| 201 | + finder.walkFunctionInModule(function, module); |
| 202 | + modify(finder.tailCalls, finder.tailCallIndirects); |
| 203 | + } else { |
| 204 | + ReturnFinder finder{getPassOptions()}; |
| 205 | + finder.walkFunctionInModule(function, module); |
| 206 | + modify(finder.tailCalls, finder.tailCallIndirects); |
| 207 | + } |
96 | 208 | ReFinalize{}.walkFunctionInModule(function, module);
|
97 | 209 | }
|
98 | 210 | };
|
|
0 commit comments