Skip to content

Commit cf7c255

Browse files
authored
[JumpThread] Fix JumpThreading pass to skip merging when both blocks contain convergence loop/entry intrinsics. (#170247)
Fixes: #165642. After this fix, optimization passes for the example in the bug. [LLVM Spec](https://llvm.org/docs/ConvergentOperations.html#llvm-experimental-convergence-loop) states that only a single loop / entry convergence token can be included in a basic block. This PR fixes the issue in `JumpThreading` pass so that when a basic block and its predecessor both contain such convergence intrinsics, it skips merging the two blocks.
1 parent 1c72c90 commit cf7c255

File tree

4 files changed

+80
-1
lines changed

4 files changed

+80
-1
lines changed

llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class ReturnInst;
4545
class TargetLibraryInfo;
4646
class Value;
4747

48+
/// Check if the given basic block contains any loop or entry convergent
49+
/// intrinsic instructions.
50+
LLVM_ABI bool HasLoopOrEntryConvergenceToken(const BasicBlock *BB);
51+
4852
/// Replace contents of every block in \p BBs with single unreachable
4953
/// instruction. If \p Updates is specified, collect all necessary DT updates
5054
/// into this vector. If \p KeepOneInputPHIs is true, one-input Phis in

llvm/lib/Transforms/Scalar/JumpThreading.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,6 +1920,13 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) {
19201920
if (Unreachable.count(SinglePred))
19211921
return false;
19221922

1923+
// Don't merge if both the basic block and the predecessor contain loop or
1924+
// entry convergent intrinsics, since there may only be one convergence token
1925+
// per block.
1926+
if (HasLoopOrEntryConvergenceToken(BB) &&
1927+
HasLoopOrEntryConvergenceToken(SinglePred))
1928+
return false;
1929+
19231930
// If SinglePred was a loop header, BB becomes one.
19241931
if (LoopHeaders.erase(SinglePred))
19251932
LoopHeaders.insert(BB);

llvm/lib/Transforms/Utils/BasicBlockUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ emptyAndDetachBlock(BasicBlock *BB,
9292
"applying corresponding DTU updates.");
9393
}
9494

95-
static bool HasLoopOrEntryConvergenceToken(const BasicBlock *BB) {
95+
bool llvm::HasLoopOrEntryConvergenceToken(const BasicBlock *BB) {
9696
for (const Instruction &I : *BB) {
9797
const ConvergenceControlInst *CCI = dyn_cast<ConvergenceControlInst>(&I);
9898
if (CCI && (CCI->isLoop() || CCI->isEntry()))
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -S -passes=jump-threading | FileCheck %s
3+
4+
declare token @llvm.experimental.convergence.entry() #0
5+
6+
define void @nested(i32 %tidx, i32 %tidy, ptr %array) #0 {
7+
; CHECK-LABEL: @nested(
8+
; CHECK-NEXT: entry:
9+
; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
10+
; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[TIDY:%.*]], [[TIDX:%.*]]
11+
; CHECK-NEXT: [[OR_COND_I:%.*]] = icmp eq i32 [[TMP1]], 0
12+
; CHECK-NEXT: br label [[FOR_COND_I:%.*]]
13+
; CHECK: for.cond.i:
14+
; CHECK-NEXT: [[TMP2:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
15+
; CHECK-NEXT: br label [[FOR_COND1_I:%.*]]
16+
; CHECK: for.cond1.i:
17+
; CHECK-NEXT: [[CMP2_I:%.*]] = phi i1 [ false, [[FOR_BODY4_I:%.*]] ], [ true, [[FOR_COND_I]] ]
18+
; CHECK-NEXT: [[TMP3:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP2]]) ]
19+
; CHECK-NEXT: br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[EXIT:%.*]]
20+
; CHECK: for.body4.i:
21+
; CHECK-NEXT: br i1 [[OR_COND_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]]
22+
; CHECK: if.then.i:
23+
; CHECK-NEXT: [[TEST_VAL:%.*]] = call spir_func i32 @func_test(i32 0) [ "convergencectrl"(token [[TMP3]]) ]
24+
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 0
25+
; CHECK-NEXT: store i32 [[TEST_VAL]], ptr [[TMP4]], align 4
26+
; CHECK-NEXT: br label [[EXIT]]
27+
; CHECK: exit:
28+
; CHECK-NEXT: ret void
29+
;
30+
entry:
31+
%0 = tail call token @llvm.experimental.convergence.entry()
32+
%2 = or i32 %tidy, %tidx
33+
%or.cond.i = icmp eq i32 %2, 0
34+
br label %for.cond.i
35+
36+
for.cond.i:
37+
%3 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
38+
br label %for.cond1.i
39+
40+
for.cond1.i:
41+
%cmp2.i = phi i1 [ false, %for.body4.i ], [ true, %for.cond.i ]
42+
%4 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %3) ]
43+
br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit
44+
45+
for.body4.i:
46+
br i1 %or.cond.i, label %if.then.i, label %for.cond1.i
47+
48+
if.then.i:
49+
%test.val = call spir_func i32 @func_test(i32 0) [ "convergencectrl"(token %4) ]
50+
%5 = getelementptr inbounds i32, ptr %array, i32 0
51+
store i32 %test.val, ptr %5, align 4
52+
br label %cleanup.i
53+
54+
cleanup.i.loopexit:
55+
br label %cleanup.i
56+
57+
cleanup.i:
58+
br label %exit
59+
60+
exit:
61+
ret void
62+
}
63+
64+
declare token @llvm.experimental.convergence.loop() #0
65+
66+
declare i32 @func_test(i32) #0
67+
68+
attributes #0 = { convergent }

0 commit comments

Comments
 (0)