Skip to content

Commit 826f5aa

Browse files
[InstCombine] Optimize AMDGPU ballot + assume uniformity patterns
When we encounter assume(ballot(cmp) == -1), we know that cmp is uniform across all lanes and evaluates to true. This optimization recognizes this pattern and replaces the condition with a constant true, allowing subsequent passes to eliminate dead code and optimize control flow. The optimization handles both i32 and i64 ballot intrinsics and only applies when the ballot result is compared against -1 (all lanes active). This is a conservative approach that ensures correctness while enabling significant optimizations for uniform control flow patterns.
1 parent 3149a77 commit 826f5aa

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3540,6 +3540,39 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35403540
}
35413541
}
35423542

3543+
// Optimize AMDGPU ballot uniformity assumptions:
3544+
// assume(icmp eq (ballot(cmp), -1)) implies that cmp is uniform and true
3545+
// This allows us to optimize away the ballot and replace cmp with true
3546+
Value *BallotInst;
3547+
if (match(IIOperand, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(BallotInst),
3548+
m_AllOnes()))) {
3549+
// Check if this is an AMDGPU ballot intrinsic
3550+
if (auto *BallotCall = dyn_cast<IntrinsicInst>(BallotInst)) {
3551+
if (BallotCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
3552+
Value *BallotCondition = BallotCall->getArgOperand(0);
3553+
3554+
// If ballot(cmp) == -1, then cmp is uniform across all lanes and
3555+
// evaluates to true We can safely replace BallotCondition with true
3556+
// since ballot == -1 implies all lanes are true
3557+
if (BallotCondition->getType()->isIntOrIntVectorTy(1) &&
3558+
!isa<Constant>(BallotCondition)) {
3559+
3560+
// Add the condition to the worklist for further optimization
3561+
Worklist.pushValue(BallotCondition);
3562+
3563+
// Replace BallotCondition with true
3564+
BallotCondition->replaceAllUsesWith(
3565+
ConstantInt::getTrue(BallotCondition->getType()));
3566+
3567+
// The assumption is now always true, so we can simplify it
3568+
replaceUse(II->getOperandUse(0),
3569+
ConstantInt::getTrue(II->getContext()));
3570+
return II;
3571+
}
3572+
}
3573+
}
3574+
}
3575+
35433576
// If there is a dominating assume with the same condition as this one,
35443577
// then this one is redundant, and should be removed.
35453578
KnownBits Known(1);
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

; Test case for optimizing AMDGPU ballot + assume patterns.
; assume(ballot(cmp) == -1) means every ballot bit is set, which is only
; possible when %cmp is true in all lanes; instcombine may therefore fold
; the guarded branch condition to true. The branch itself is kept
; (instcombine does not restructure the CFG), but with a constant condition.

define void @test_assume_ballot_uniform(i32 %x) {
; CHECK-LABEL: @test_assume_ballot_uniform(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
; CHECK: foo:
; CHECK-NEXT: ret void
; CHECK: bar:
; CHECK-NEXT: ret void
;
entry:
  %cmp = icmp eq i32 %x, 0
  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
  %all = icmp eq i64 %ballot, -1
  call void @llvm.assume(i1 %all)
  br i1 %cmp, label %foo, label %bar

foo:
  ret void

bar:
  ret void
}
29+
30+
; Test case with partial optimization - only ballot removal without branch optimization
; NOTE(review): despite the header above, this function is byte-for-byte
; identical to @test_assume_ballot_uniform (same body, same CHECK lines) and
; adds no coverage. Either make it exercise a genuinely partial fold (e.g. a
; second, non-dominated use of %cmp) or drop it.
define void @test_assume_ballot_partial(i32 %x) {
; CHECK-LABEL: @test_assume_ballot_partial(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
; CHECK: foo:
; CHECK-NEXT: ret void
; CHECK: bar:
; CHECK-NEXT: ret void
;
entry:
  %cmp = icmp eq i32 %x, 0
  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
  %all = icmp eq i64 %ballot, -1
  call void @llvm.assume(i1 %all)
  br i1 %cmp, label %foo, label %bar

foo:
  ret void

bar:
  ret void
}
53+
54+
; Negative test - ballot not compared to -1.
; assume(ballot(cmp) != 0) only says *some* lane is true, not that %cmp is
; uniform, so no fold may fire and the IR must be unchanged.

define void @test_assume_ballot_not_uniform(i32 %x) {
; CHECK-LABEL: @test_assume_ballot_not_uniform(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
; CHECK-NEXT: [[SOME:%.*]] = icmp ne i64 [[BALLOT]], 0
; CHECK-NEXT: call void @llvm.assume(i1 [[SOME]])
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
; CHECK: foo:
; CHECK-NEXT: ret void
; CHECK: bar:
; CHECK-NEXT: ret void
;
entry:
  %cmp = icmp eq i32 %x, 0
  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
  %some = icmp ne i64 %ballot, 0
  call void @llvm.assume(i1 %some)
  br i1 %cmp, label %foo, label %bar

foo:
  ret void

bar:
  ret void
}
81+
82+
; Test with 32-bit ballot (the wave32 form of the intrinsic — TODO confirm
; the test is run with a wave32 target assumption). ballot.i32 == -1 means
; all 32 lanes are true, so the same fold applies as in the i64 case.

define void @test_assume_ballot_uniform_i32(i32 %x) {
; CHECK-LABEL: @test_assume_ballot_uniform_i32(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
; CHECK: foo:
; CHECK-NEXT: ret void
; CHECK: bar:
; CHECK-NEXT: ret void
;
entry:
  %cmp = icmp eq i32 %x, 0
  %ballot = call i32 @llvm.amdgcn.ballot.i32(i1 %cmp)
  %all = icmp eq i32 %ballot, -1
  call void @llvm.assume(i1 %all)
  br i1 %cmp, label %foo, label %bar

foo:
  ret void

bar:
  ret void
}
105+
106+
declare i64 @llvm.amdgcn.ballot.i64(i1)
107+
declare i32 @llvm.amdgcn.ballot.i32(i1)
108+
declare void @llvm.assume(i1)

0 commit comments

Comments
 (0)