Skip to content

Commit 8bf8625

Browse files
authored
Take loop exit blocks into account in constructing edges for top sort when dealing with a loop header (#908)
Fixes #748.
1 parent fae9750 commit 8bf8625

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

ir/function.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ void BasicBlock::delInstr(const Instr *i) {
5959
}
6060
}
6161

62+
void BasicBlock::addExitBlock(BasicBlock* bb) {
63+
exit_blocks.emplace(bb);
64+
}
65+
6266
vector<Phi*> BasicBlock::phis() const {
6367
vector<Phi*> phis;
6468
for (auto &i : m_instrs) {
@@ -358,6 +362,17 @@ static vector<BasicBlock*> top_sort(const vector<BasicBlock*> &bbs) {
358362
if (dst_I != bb_map.end())
359363
edges[i].emplace(dst_I->second);
360364
}
365+
366+
// If `bb` is a loop header, we need to go through its exit block
367+
// in order to account for some transitive dependencies we may have
368+
// missed due to compression of its inner loops.
369+
// If there are no inner loops, this is redundant and if `bb` is not
370+
// a loop header, the set of its exit blocks is empty.
371+
for (auto &dst : bb->getExitBlocks()) {
372+
auto dst_I = bb_map.find(dst);
373+
if (dst_I != bb_map.end())
374+
edges[i].emplace(dst_I->second);
375+
}
361376
++i;
362377
}
363378

@@ -548,6 +563,7 @@ void Function::unroll(unsigned k) {
548563
for (auto &dst : bb->targets()) {
549564
if (!bbmap.count(&dst)) {
550565
exit_edges.emplace(bb, const_cast<BasicBlock*>(&dst));
566+
header->addExitBlock(const_cast<BasicBlock*>(&dst));
551567
}
552568
}
553569
}

ir/function.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <string_view>
1616
#include <tuple>
1717
#include <unordered_map>
18+
#include <unordered_set>
1819
#include <vector>
1920

2021
namespace smt { class Model; }
@@ -27,6 +28,9 @@ class BasicBlock final {
2728
std::string name;
2829
std::vector<std::unique_ptr<Instr>> m_instrs;
2930

31+
// If the basic block is a header, this holds all exit blocks of its loop
32+
std::unordered_set<BasicBlock*> exit_blocks;
33+
3034
public:
3135
BasicBlock(std::string_view name) : name(name) {}
3236

@@ -39,6 +43,11 @@ class BasicBlock final {
3943
void addInstrAt(std::unique_ptr<Instr> &&i, const Instr *other, bool before);
4044
void delInstr(const Instr *i);
4145

46+
void addExitBlock(BasicBlock* bb);
47+
const std::unordered_set<BasicBlock*>& getExitBlocks() const {
48+
return exit_blocks;
49+
}
50+
4251
util::const_strip_unique_ptr<decltype(m_instrs)> instrs() const {
4352
return m_instrs;
4453
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
; TEST-ARGS: -src-unroll=2 -tgt-unroll=2
2+
; llvm/llvm/test/Transforms/LoopSimplifyCFG/constant-fold-branch.ll
3+
4+
define i32 @src(i32 %end) {
5+
entry:
6+
br label %outer_header
7+
8+
outer_header:
9+
%j = phi i32 [0, %entry], [%j.inc, %outer_backedge]
10+
br label %preheader
11+
12+
preheader:
13+
br label %header
14+
15+
header:
16+
%i = phi i32 [0, %preheader], [%i.inc, %backedge]
17+
br label %backedge
18+
19+
backedge:
20+
%i.inc = add i32 %i, 1
21+
%cmp = icmp slt i32 %i.inc, %end
22+
br i1 %cmp, label %header, label %outer_backedge
23+
24+
outer_backedge:
25+
%j.inc = add i32 %j, 1
26+
%cmp.j = icmp slt i32 %j.inc, %end
27+
br i1 %cmp.j, label %outer_header, label %exit
28+
29+
exit:
30+
ret i32 %i.inc
31+
}
32+
33+
define i32 @tgt(i32 %end) {
34+
entry:
35+
br label %outer_header
36+
37+
outer_header: ; preds = %outer_backedge, %entry
38+
%j = phi i32 [ 0, %entry ], [ %j.inc, %outer_backedge ]
39+
br label %preheader
40+
41+
preheader: ; preds = %outer_header
42+
br label %header
43+
44+
header: ; preds = %backedge, %preheader
45+
%i = phi i32 [ 0, %preheader ], [ %i.inc, %backedge ]
46+
br label %backedge
47+
48+
backedge: ; preds = %header
49+
%i.inc = add i32 %i, 1
50+
%cmp = icmp slt i32 %i.inc, %end
51+
br i1 %cmp, label %header, label %outer_backedge
52+
53+
outer_backedge: ; preds = %backedge
54+
%i.inc.lcssa = phi i32 [ %i.inc, %backedge ]
55+
%j.inc = add i32 %j, 1
56+
%cmp.j = icmp slt i32 %j.inc, %end
57+
br i1 %cmp.j, label %outer_header, label %exit
58+
59+
exit: ; preds = %outer_backedge
60+
%i.inc.lcssa.lcssa = phi i32 [ %i.inc.lcssa, %outer_backedge ]
61+
ret i32 %i.inc.lcssa.lcssa
62+
}

0 commit comments

Comments
 (0)