Skip to content

Commit 85146eb

Browse files
authored
Split continues with convergent calls in some cases (#1439)
* Adds a new transform to FixupStructuredCFG to split a continue (latch) into two blocks under certain circumstances: * The loop header is a conditional branch to the body and latch * The latch has two predecessors * The latch contains a convergent call * This transformation prevents clspv forces (along with the breakConditionalHeader transform in the same pass) to prevent convergent operations from being placed in the loop continue contruct. Instead they end up as a structured selection in the body. This ensures reconvergence more robustly than previously. SPIRV-Cross, for example, inlines continues into the body under the assumption that reconvergence is not expected
1 parent 7b4ea7e commit 85146eb

File tree

5 files changed

+210
-5
lines changed

5 files changed

+210
-5
lines changed

lib/FixupStructuredCFGPass.cpp

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ PreservedAnalyses
2626
clspv::FixupStructuredCFGPass::run(Function &F, FunctionAnalysisManager &FAM) {
2727
// Assumes CFG has been structurized.
2828
isolateContinue(F, FAM);
29-
// Run after isolateContinue since this can invalidate loop info.
29+
isolateConvergentLatch(F, FAM);
3030
breakConditionalHeader(F, FAM);
3131

3232
removeUndefPHI(F);
@@ -75,6 +75,97 @@ void clspv::FixupStructuredCFGPass::removeUndefPHI(Function &F) {
7575
}
7676
}
7777

78+
void clspv::FixupStructuredCFGPass::isolateConvergentLatch(
79+
Function &F, FunctionAnalysisManager &FAM) {
80+
auto &LI = FAM.getResult<LoopAnalysis>(F);
81+
82+
std::vector<BasicBlock *> blocks;
83+
blocks.reserve(F.size());
84+
for (auto &BB : F) {
85+
blocks.push_back(&BB);
86+
}
87+
88+
for (auto *BB : blocks) {
89+
if (!LI.isLoopHeader(BB))
90+
continue;
91+
92+
auto *loop = LI.getLoopFor(BB);
93+
auto *latch = loop->getLoopLatch();
94+
// Skip single block loops.
95+
if (!latch || latch == BB) {
96+
continue;
97+
}
98+
99+
// Latch needs two predecessors.
100+
if (!latch->hasNPredecessors(2)) {
101+
continue;
102+
}
103+
104+
// Header is a conditional branch.
105+
auto header_terminator = dyn_cast_or_null<BranchInst>(BB->getTerminator());
106+
if (!header_terminator || !header_terminator->isConditional()) {
107+
continue;
108+
}
109+
110+
// One edge jumps to the continue target.
111+
if (header_terminator->getSuccessor(0) != latch &&
112+
header_terminator->getSuccessor(1) != latch) {
113+
continue;
114+
}
115+
116+
// The continue contains a convergent call.
117+
bool has_convergent_call = false;
118+
for (auto &inst : *latch) {
119+
if (auto *call = dyn_cast<CallInst>(&inst)) {
120+
if (call->hasFnAttr(Attribute::Convergent)) {
121+
has_convergent_call = true;
122+
break;
123+
}
124+
}
125+
}
126+
if (!has_convergent_call) {
127+
continue;
128+
}
129+
130+
auto *latch_terminator =
131+
dyn_cast_or_null<BranchInst>(latch->getTerminator());
132+
if (!latch_terminator)
133+
continue;
134+
135+
// Break the latch such that it is a single-entry single-exit block.
136+
// This will force later transforms in this fixup to break the loop header
137+
// which puts the whole loop body as a selection.
138+
if (latch_terminator->isConditional()) {
139+
// Safety valve: if this is not an exiting block then the loop is not
140+
// structured as expected.
141+
if (!loop->isLoopExiting(latch)) {
142+
continue;
143+
}
144+
145+
// Conditional branch case: one edge back to header and one out of the
146+
// loop. Transformed into one edge out of the loop and one edge to the new
147+
// continue and thence to the header.
148+
auto new_latch =
149+
BasicBlock::Create(F.getContext(), "", &F, latch->getNextNode());
150+
BranchInst::Create(BB, new_latch);
151+
loop->addBlockEntry(new_latch);
152+
153+
const auto idx = latch_terminator->getSuccessor(0) == BB ? 0 : 1;
154+
latch_terminator->setSuccessor(idx, new_latch);
155+
156+
// Update phis to use the new basic block.
157+
for (auto iter = BB->begin(); &*iter != BB->getFirstNonPHI(); ++iter) {
158+
PHINode *phi = cast<PHINode>(&*iter);
159+
phi->replaceIncomingBlockWith(latch, new_latch);
160+
}
161+
} else {
162+
// Simple case: just split the block.
163+
auto new_block = latch->splitBasicBlockBefore(latch_terminator);
164+
loop->addBlockEntry(new_block);
165+
}
166+
}
167+
}
168+
78169
void clspv::FixupStructuredCFGPass::breakConditionalHeader(
79170
Function &F, FunctionAnalysisManager &FAM) {
80171
auto &LI = FAM.getResult<LoopAnalysis>(F);
@@ -106,7 +197,8 @@ void clspv::FixupStructuredCFGPass::breakConditionalHeader(
106197
bool succ2_in_body = succ2 != latch && succ2 != exit;
107198

108199
if (succ1_in_body && succ2_in_body) {
109-
BB->splitBasicBlockBefore(terminator);
200+
auto new_block = BB->splitBasicBlockBefore(terminator);
201+
loop->addBlockEntry(new_block);
110202
}
111203
}
112204
}

lib/FixupStructuredCFGPass.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,35 @@ struct FixupStructuredCFGPass : llvm::PassInfoMixin<FixupStructuredCFGPass> {
2828
void breakConditionalHeader(llvm::Function &F, llvm::FunctionAnalysisManager &FAM);
2929
void isolateContinue(llvm::Function &F, llvm::FunctionAnalysisManager &FAM);
3030

31+
/**
32+
* Transforms a loop such as:
33+
*
34+
* header --\
35+
* / \ |
36+
* body | |
37+
* \ / ^
38+
* latch |
39+
* / \ |
40+
* exit ---/
41+
*
42+
* Into:
43+
* header --------\
44+
* / \ |
45+
* body | |
46+
* \ / ^
47+
* old_latch |
48+
* / \ |
49+
* exit new_latch -/
50+
*
51+
* When the latch contains a convergent call (e.g. a barrier). This will force
52+
* breakConditionalHeader to transform the loop also and effectively
53+
* encapsulates body within a selection now fully contained in the body of the
54+
* loop. This effectively moves the convergent call out of the latch where
55+
* SPIR-V does not guarantee reconvergence (without maximal reconvergence)
56+
* into a fully structured section where reconvergence is guaranteed.
57+
*/
58+
void isolateConvergentLatch(llvm::Function &F,
59+
llvm::FunctionAnalysisManager &FAM);
3160
};
3261
} // namespace clspv
3362

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; RUN: clspv-opt --passes=fixup-structured-cfg %s -o %t.ll
2+
; RUN: FileCheck %s < %t.ll
3+
4+
; CHECK: entry:
5+
; CHECK-NEXT: br label %[[new_header:[a-zA-Z0-9_.]+]]
6+
; CHECK: [[new_header]]:
7+
; CHECK-NEXT: br label %loop
8+
; CHECK: loop:
9+
; CHECK-NEXT: br i1 undef, label %then, label %[[pre_cont:[a-zA-Z0-9_.]+]]
10+
; CHECK: then:
11+
; CHECK-NEXT: br i1 undef, label %[[pre_cont]], label %exit
12+
; CHECK: [[pre_cont]]:
13+
; CHECK: call void @_Z8spirv.op.224
14+
; CHECK-NEXT: br label %[[cont:[a-zA-Z0-9_.]+]]
15+
; CHECK: [[cont]]:
16+
; CHECK-NEXT: br label %[[new_header]]
17+
18+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
19+
target triple = "spir-unknown-unknown"
20+
21+
define spir_kernel void @test() {
22+
entry:
23+
br label %loop
24+
25+
loop:
26+
br i1 undef, label %then, label %cont
27+
28+
then:
29+
br i1 undef, label %cont, label %exit
30+
31+
cont:
32+
tail call void @_Z8spirv.op.224.jjj(i32 224, i32 2, i32 2, i32 264) #0
33+
br label %loop
34+
35+
exit:
36+
ret void
37+
}
38+
39+
attributes #0 = { convergent }
40+
41+
declare void @_Z8spirv.op.224.jjj(i32, i32, i32, i32) #0
42+
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
; RUN: clspv-opt --passes=fixup-structured-cfg %s -o %t.ll
2+
; RUN: FileCheck %s < %t.ll
3+
4+
; CHECK: entry:
5+
; CHECK-NEXT: br label %[[new_header:[a-zA-Z0-9_.]+]]
6+
; CHECK: [[new_header]]:
7+
; CHECK-NEXT: phi i32 [ 0, %entry ], [ 1, %[[cont:[a-zA-Z0-9_.]+]] ]
8+
; CHECK-NEXT: br label %loop
9+
; CHECK: loop:
10+
; CHECK-NEXT: br i1 undef, label %then, label %[[pre_cont:[a-zA-Z0-9_.]+]]
11+
; CHECK: then:
12+
; CHECK-NEXT: br label %[[pre_cont]]
13+
; CHECK: [[pre_cont]]:
14+
; CHECK: call void @_Z8spirv.op.224
15+
; CHECK-NEXT: br i1 undef, label %[[cont]], label %exit
16+
; CHECK: [[cont]]:
17+
; CHECK-NEXT: br label %[[new_header]]
18+
19+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
20+
target triple = "spir-unknown-unknown"
21+
22+
define spir_kernel void @test() {
23+
entry:
24+
br label %loop
25+
26+
loop:
27+
%0 = phi i32 [ 0, %entry ], [ 1, %cont ]
28+
br i1 undef, label %then, label %cont
29+
30+
then:
31+
br label %cont
32+
33+
cont:
34+
tail call void @_Z8spirv.op.224.jjj(i32 224, i32 2, i32 2, i32 264) #0
35+
br i1 undef, label %loop, label %exit
36+
37+
exit:
38+
ret void
39+
}
40+
41+
attributes #0 = { convergent }
42+
43+
declare void @_Z8spirv.op.224.jjj(i32, i32, i32, i32) #0
44+

test/loop_continue_no_selection_merge.cl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
// CHECK: OpBranch [[CONT]]
1313
// CHECK: [[CONT]] = OpLabel
1414
// CHECK-NOT: OpLabel
15-
// CHECK: OpControlBarrier
16-
// CHECK-NOT: OpLabel
1715
// CHECK: OpBranchConditional {{.*}} [[MERGE]] [[LOOP]]
1816

1917
__kernel void
@@ -27,6 +25,7 @@ top_scan(__global uint * isums,
2725
int last_thread = (get_local_id(0) < n &&
2826
(get_local_id(0)+1) == n) ? 1 : 0;
2927

28+
#pragma unroll 0
3029
for (int d = 0; d < 16; d++)
3130
{
3231
int idx = get_local_id(0);
@@ -35,7 +34,6 @@ top_scan(__global uint * isums,
3534
{
3635
s_seed += 42;
3736
}
38-
barrier(CLK_LOCAL_MEM_FENCE);
3937
}
4038
}
4139

0 commit comments

Comments
 (0)