Skip to content

Commit 3564791

Browse files
authored
[IndVarSimplify] Fix IndVarSimplify to skip unfolding predicates when the loop contains control convergence operations. (llvm#165643)
Skip constant folding of the loop exit predicates if the loop contains control convergence tokens that are referenced outside the loop.

Fixes llvm#164496.

Verified that [loop_peeling.test](llvm/offload-test-suite#473) passes with the fix.

Similar control convergence issues are found in other passes; see llvm#165642.

HLSL used for tests:

```hlsl
RWStructuredBuffer<uint> Out : register(u0);

[numthreads(8,1,1)]
void main(uint3 TID : SV_GroupThreadID) {
  for (uint i = 0; i < 8; i++) {
    if (i == TID.x) {
      Out[TID.x] = WaveActiveMax(TID.x);
      break;
    }
  }
}
```

With nested loop:

```hlsl
RWStructuredBuffer<uint> Out : register(u0);

[numthreads(8,8,1)]
void main(uint3 TID : SV_GroupThreadID) {
  for (uint i = 0; i < 8; i++) {
    for (uint j = 0; j < 8; j++) {
      if (i == TID.x && j == TID.y) {
        uint index = TID.x * 8 + TID.y;
        Out[index] = WaveActiveMax(index);
        break;
      }
    }
  }
}
```
1 parent b78b5ba commit 3564791

File tree

3 files changed

+172
-1
lines changed

3 files changed

+172
-1
lines changed

llvm/lib/Transforms/Scalar/IndVarSimplify.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1855,7 +1855,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
18551855
// is that enough for *all* side effects?
18561856
bool HasThreadLocalSideEffects = false;
18571857
for (BasicBlock *BB : L->blocks())
1858-
for (auto &I : *BB)
1858+
for (auto &I : *BB) {
18591859
// TODO:isGuaranteedToTransfer
18601860
if (I.mayHaveSideEffects()) {
18611861
if (!LoopPredicationTraps)
@@ -1873,6 +1873,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
18731873
}
18741874
}
18751875

1876+
// Skip if the loop has tokens referenced outside the loop to avoid
1877+
// changing convergence behavior.
1878+
if (I.getType()->isTokenTy()) {
1879+
for (User *U : I.users()) {
1880+
Instruction *UserInst = dyn_cast<Instruction>(U);
1881+
if (UserInst && !L->contains(UserInst)) {
1882+
return false;
1883+
}
1884+
}
1885+
}
1886+
}
1887+
18761888
bool Changed = false;
18771889
// Finally, do the actual predication for all predicatable blocks. A couple
18781890
// of notes here:
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
3+
4+
; The loop's convergence token is used by a convergent call in an exit block outside the loop, so IndVarSimplify must not predicate the loop exits.
5+
6+
declare token @llvm.experimental.convergence.entry() #0
7+
8+
define void @loop(i32 %tid, ptr %array) #0 {
9+
; CHECK-LABEL: @loop(
10+
; CHECK-NEXT: entry:
11+
; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
12+
; CHECK-NEXT: br label [[FOR_COND_I:%.*]]
13+
; CHECK: for.cond.i:
14+
; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC_I:%.*]], [[FOR_BODY_I:%.*]] ]
15+
; CHECK-NEXT: [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
16+
; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
17+
; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_BODY_I]], label [[EXIT_LOOPEXIT:%.*]]
18+
; CHECK: for.body.i:
19+
; CHECK-NEXT: [[CMP1_I:%.*]] = icmp eq i32 [[I_0_I]], [[TID:%.*]]
20+
; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[I_0_I]], 1
21+
; CHECK-NEXT: br i1 [[CMP1_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND_I]]
22+
; CHECK: exit.loopexit:
23+
; CHECK-NEXT: br label [[EXIT:%.*]]
24+
; CHECK: if.then.i:
25+
; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX2_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[TID]]) [ "convergencectrl"(token [[TMP1]]) ]
26+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[TID]]
27+
; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX2_I]], ptr [[TMP2]], align 4
28+
; CHECK-NEXT: br label [[EXIT]]
29+
; CHECK: exit:
30+
; CHECK-NEXT: ret void
31+
;
32+
entry:
33+
%0 = tail call token @llvm.experimental.convergence.entry()
34+
br label %for.cond.i
35+
36+
for.cond.i:
37+
%i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.body.i ]
38+
%2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
39+
%cmp.i = icmp ult i32 %i.0.i, 8
40+
br i1 %cmp.i, label %for.body.i, label %exit.loopexit
41+
42+
for.body.i:
43+
%cmp1.i = icmp eq i32 %i.0.i, %tid
44+
%inc.i = add nuw nsw i32 %i.0.i, 1
45+
br i1 %cmp1.i, label %if.then.i, label %for.cond.i
46+
47+
exit.loopexit:
48+
br label %exit
49+
50+
if.then.i:
51+
%hlsl.wave.active.max2.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %tid) [ "convergencectrl"(token %2) ]
52+
%3 = getelementptr inbounds i32, ptr %array, i32 %tid
53+
store i32 %hlsl.wave.active.max2.i, ptr %3, align 4
54+
br label %exit
55+
56+
exit:
57+
ret void
58+
}
59+
60+
declare token @llvm.experimental.convergence.loop() #0
61+
62+
declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
63+
64+
attributes #0 = { convergent }
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
3+
4+
; Nested loops: the inner loop's convergence token is used by a convergent call outside the inner loop, so IndVarSimplify must not predicate the inner loop's exits.
5+
6+
declare token @llvm.experimental.convergence.entry() #0
7+
8+
define void @nested(i32 %tidx, i32 %tidy, ptr %array) #0 {
9+
; CHECK-LABEL: @nested(
10+
; CHECK-NEXT: entry:
11+
; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
12+
; CHECK-NEXT: [[MUL_I:%.*]] = shl nsw i32 [[TIDX:%.*]], 3
13+
; CHECK-NEXT: [[ADD_I:%.*]] = add nsw i32 [[MUL_I]], [[TIDY:%.*]]
14+
; CHECK-NEXT: br label [[FOR_COND_I:%.*]]
15+
; CHECK: for.cond.i:
16+
; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC10_I:%.*]], [[CLEANUP_I:%.*]] ]
17+
; CHECK-NEXT: [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
18+
; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
19+
; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_COND1_I_PREHEADER:%.*]], label [[EXIT:%.*]]
20+
; CHECK: for.cond1.i.preheader:
21+
; CHECK-NEXT: [[CMP5_I:%.*]] = icmp eq i32 [[I_0_I]], [[TIDX]]
22+
; CHECK-NEXT: br label [[FOR_COND1_I:%.*]]
23+
; CHECK: for.cond1.i:
24+
; CHECK-NEXT: [[J_0_I:%.*]] = phi i32 [ [[INC_I:%.*]], [[FOR_BODY4_I:%.*]] ], [ 0, [[FOR_COND1_I_PREHEADER]] ]
25+
; CHECK-NEXT: [[TMP2:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP1]]) ]
26+
; CHECK-NEXT: [[CMP2_I:%.*]] = icmp ult i32 [[J_0_I]], 8
27+
; CHECK-NEXT: br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[CLEANUP_I_LOOPEXIT:%.*]]
28+
; CHECK: for.body4.i:
29+
; CHECK-NEXT: [[CMP6_I:%.*]] = icmp eq i32 [[J_0_I]], [[TIDY]]
30+
; CHECK-NEXT: [[OR_COND:%.*]] = select i1 [[CMP5_I]], i1 [[CMP6_I]], i1 false
31+
; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[J_0_I]], 1
32+
; CHECK-NEXT: br i1 [[OR_COND]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]]
33+
; CHECK: cleanup.i.loopexit:
34+
; CHECK-NEXT: br label [[CLEANUP_I]]
35+
; CHECK: if.then.i:
36+
; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX7_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[ADD_I]]) [ "convergencectrl"(token [[TMP2]]) ]
37+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[ADD_I]]
38+
; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX7_I]], ptr [[TMP3]], align 4
39+
; CHECK-NEXT: br label [[CLEANUP_I]]
40+
; CHECK: cleanup.i:
41+
; CHECK-NEXT: [[INC10_I]] = add nuw nsw i32 [[I_0_I]], 1
42+
; CHECK-NEXT: br label [[FOR_COND_I]]
43+
; CHECK: exit:
44+
; CHECK-NEXT: ret void
45+
;
46+
entry:
47+
%0 = tail call token @llvm.experimental.convergence.entry()
48+
%mul.i = shl nsw i32 %tidx, 3
49+
%add.i = add nsw i32 %mul.i, %tidy
50+
br label %for.cond.i
51+
52+
for.cond.i:
53+
%i.0.i = phi i32 [ 0, %entry ], [ %inc10.i, %cleanup.i ]
54+
%2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
55+
%cmp.i = icmp ult i32 %i.0.i, 8
56+
br i1 %cmp.i, label %for.cond1.i.preheader, label %exit
57+
58+
for.cond1.i.preheader:
59+
%cmp5.i = icmp eq i32 %i.0.i, %tidx
60+
br label %for.cond1.i
61+
62+
for.cond1.i:
63+
%j.0.i = phi i32 [ %inc.i, %for.body4.i ], [ 0, %for.cond1.i.preheader ]
64+
%3 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %2) ]
65+
%cmp2.i = icmp ult i32 %j.0.i, 8
66+
br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit
67+
68+
for.body4.i:
69+
%cmp6.i = icmp eq i32 %j.0.i, %tidy
70+
%or.cond = select i1 %cmp5.i, i1 %cmp6.i, i1 false
71+
%inc.i = add nsw i32 %j.0.i, 1
72+
br i1 %or.cond, label %if.then.i, label %for.cond1.i
73+
74+
cleanup.i.loopexit:
75+
br label %cleanup.i
76+
77+
if.then.i:
78+
%hlsl.wave.active.max7.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %add.i) [ "convergencectrl"(token %3) ]
79+
%4 = getelementptr inbounds i32, ptr %array, i32 %add.i
80+
store i32 %hlsl.wave.active.max7.i, ptr %4, align 4
81+
br label %cleanup.i
82+
83+
cleanup.i:
84+
%inc10.i = add nsw i32 %i.0.i, 1
85+
br label %for.cond.i
86+
87+
exit:
88+
ret void
89+
}
90+
91+
declare token @llvm.experimental.convergence.loop() #0
92+
93+
declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
94+
95+
attributes #0 = { convergent }

0 commit comments

Comments
 (0)