Skip to content

Commit 3be44c1

Browse files
authored
Properly optimize loop values (#1800)
Remove the existing hack, and optimize them just like we do for ifs and blocks. This is now able to handle a few more cases than before.
1 parent 26ef829 commit 3be44c1

16 files changed

+5997
-5894
lines changed

src/passes/SimplifyLocals.cpp

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals<a
195195

196196
void visitLoop(Loop* curr) {
197197
if (allowStructure) {
198-
if (canUseLoopReturnValue(curr)) {
199-
loops.push_back(this->getCurrentPointer());
200-
}
198+
optimizeLoopReturn(curr);
201199
}
202200
}
203201

@@ -335,7 +333,37 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals<a
335333

336334
std::vector<Block*> blocksToEnlarge;
337335
std::vector<If*> ifsToEnlarge;
338-
std::vector<Expression**> loops;
336+
std::vector<Loop*> loopsToEnlarge;
337+
338+
void optimizeLoopReturn(Loop* loop) {
339+
// If there is a sinkable set_local in an eligible loop, we can trivially
340+
// move it to the outside of the loop, making the loop return its value.
341+
if (loop->type != none) return;
342+
if (sinkables.empty()) return;
343+
Index goodIndex = sinkables.begin()->first;
344+
// Ensure we have a place to write the return value to; if not, we
345+
// need another cycle.
346+
auto* block = loop->body->dynCast<Block>();
347+
if (!block || block->name.is() || block->list.size() == 0 || !block->list.back()->is<Nop>()) {
348+
loopsToEnlarge.push_back(loop);
349+
return;
350+
}
351+
Builder builder(*this->getModule());
352+
auto** item = sinkables.at(goodIndex).item;
353+
auto* set = (*item)->template cast<SetLocal>();
354+
block->list[block->list.size() - 1] = set->value;
355+
*item = builder.makeNop();
356+
block->finalize();
357+
assert(block->type != none);
358+
loop->finalize();
359+
set->value = loop;
360+
set->finalize();
361+
this->replaceCurrent(set);
362+
// We moved things around, clear all tracking; we'll do another cycle
363+
// anyhow.
364+
sinkables.clear();
365+
anotherCycle = true;
366+
}
339367

340368
void optimizeBlockReturn(Block* block) {
341369
if (!block->name.is() || unoptimizableBlocks.count(block->name) > 0) {
@@ -685,23 +713,18 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals<a
685713
ifsToEnlarge.clear();
686714
anotherCycle = true;
687715
}
688-
// handle loops. note that a lot happens in this pass, and we can't just modify
689-
// set_locals when we see a loop - it might be tracked - and we can't also just
690-
// assume our loop didn't move either (might be in a block now). So we do this
691-
// when all other work is done - as loop return values are rare, that is fine.
692-
if (!anotherCycle) {
693-
for (auto* currp : loops) {
694-
auto* curr = (*currp)->template cast<Loop>();
695-
assert(canUseLoopReturnValue(curr));
696-
auto* set = curr->body->template cast<SetLocal>();
697-
curr->body = set->value;
698-
set->value = curr;
699-
curr->finalize(curr->body->type);
700-
*currp = set;
701-
anotherCycle = true;
716+
// enlarge loops that were marked, for the next round
717+
if (loopsToEnlarge.size() > 0) {
718+
for (auto* loop : loopsToEnlarge) {
719+
auto block = Builder(*this->getModule()).blockifyWithName(loop->body, Name());
720+
loop->body = block;
721+
if (block->list.size() == 0 || !block->list.back()->template is<Nop>()) {
722+
block->list.push_back(this->getModule()->allocator.template alloc<Nop>());
723+
}
702724
}
725+
loopsToEnlarge.clear();
726+
anotherCycle = true;
703727
}
704-
loops.clear();
705728
// clean up
706729
sinkables.clear();
707730
blockBreaks.clear();
@@ -847,17 +870,6 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals<a
847870

848871
return eqOpter.anotherCycle || setRemover.anotherCycle;
849872
}
850-
851-
bool canUseLoopReturnValue(Loop* curr) {
852-
// Optimizing a loop return value is trivial: just see if it contains
853-
// a set_local, and pull that out.
854-
if (auto* set = curr->body->template dynCast<SetLocal>()) {
855-
if (isConcreteType(set->value->type)) {
856-
return true;
857-
}
858-
}
859-
return false;
860-
}
861873
};
862874

863875
Pass *createSimplifyLocalsPass() {

0 commit comments

Comments
 (0)