Skip to content

Commit 98274dd

Browse files
[InstCombine] Add constant folding for AMDGPU ballot intrinsics
Address reviewer feedback by implementing free-form ballot intrinsic optimization instead of assume-dependent patterns. This approach: 1. Optimizes ballot(constant) directly as a standard intrinsic optimization 2. Allows uniformity analysis to handle assumes through proper channels 3. Follows established AMDGPU intrinsic patterns (amdgcn_cos, amdgcn_sin) 4. Enables broader optimization opportunities beyond assume contexts Optimizations: - ballot(true) -> -1 (all lanes active) - ballot(false) -> 0 (no lanes active) This addresses the core reviewer concern about performing optimization in assume context rather than as a free-form pattern, and lets the uniformity analysis framework handle assumes as intended. Test cases focus on constant folding rather than assume-specific patterns, demonstrating the more general applicability of this approach.
1 parent 1bd8d44 commit 98274dd

File tree

4 files changed

+130
-140
lines changed

4 files changed

+130
-140
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ using namespace PatternMatch;
8585

8686
STATISTIC(NumSimplified, "Number of library calls simplified");
8787

88+
89+
8890
static cl::opt<unsigned> GuardWideningWindow(
8991
"instcombine-guard-widening-window",
9092
cl::init(3),
@@ -2996,6 +2998,20 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
29962998
}
29972999
break;
29983000
}
3001+
case Intrinsic::amdgcn_ballot: {
3002+
// Optimize ballot intrinsics when the condition is known to be uniform
3003+
Value *Condition = II->getArgOperand(0);
3004+
3005+
// If the condition is a constant, we can evaluate the ballot directly
3006+
if (auto *ConstCond = dyn_cast<ConstantInt>(Condition)) {
3007+
// ballot(true) -> -1 (all lanes active)
3008+
// ballot(false) -> 0 (no lanes active)
3009+
uint64_t Result = ConstCond->isOne() ? ~0ULL : 0ULL;
3010+
return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result));
3011+
}
3012+
3013+
break;
3014+
}
29993015
case Intrinsic::ldexp: {
30003016
// ldexp(ldexp(x, a), b) -> ldexp(x, a + b)
30013017
//
@@ -3549,38 +3565,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35493565
}
35503566
}
35513567

3552-
// Optimize AMDGPU ballot uniformity assumptions:
3553-
// assume(icmp eq (ballot(cmp), -1)) implies that cmp is uniform and true
3554-
// This allows us to optimize away the ballot and replace cmp with true
3555-
Value *BallotInst;
3556-
if (match(IIOperand, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(BallotInst),
3557-
m_AllOnes()))) {
3558-
// Check if this is an AMDGPU ballot intrinsic
3559-
if (auto *BallotCall = dyn_cast<IntrinsicInst>(BallotInst)) {
3560-
if (BallotCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
3561-
Value *BallotCondition = BallotCall->getArgOperand(0);
3562-
3563-
// If ballot(cmp) == -1, then cmp is uniform across all lanes and
3564-
// evaluates to true We can safely replace BallotCondition with true
3565-
// since ballot == -1 implies all lanes are true
3566-
if (BallotCondition->getType()->isIntOrIntVectorTy(1) &&
3567-
!isa<Constant>(BallotCondition)) {
3568-
3569-
// Add the condition to the worklist for further optimization
3570-
Worklist.pushValue(BallotCondition);
3571-
3572-
// Replace BallotCondition with true
3573-
BallotCondition->replaceAllUsesWith(
3574-
ConstantInt::getTrue(BallotCondition->getType()));
3575-
3576-
// The assumption is now always true, so we can simplify it
3577-
replaceUse(II->getOperandUse(0),
3578-
ConstantInt::getTrue(II->getContext()));
3579-
return II;
3580-
}
3581-
}
3582-
}
3583-
}
3568+
35843569

35853570
// If there is a dominating assume with the same condition as this one,
35863571
// then this one is redundant, and should be removed.
@@ -3595,6 +3580,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35953580
return eraseInstFromFunction(*II);
35963581
}
35973582

3583+
3584+
35983585
// Update the cache of affected values for this assumption (we might be
35993586
// here because we just simplified the condition).
36003587
AC.updateAffectedValues(cast<AssumeInst>(II));

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
124124
BinaryOperator &I);
125125
Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
126126
BinaryOperator &OldAShr);
127+
128+
127129
Instruction *visitAShr(BinaryOperator &I);
128130
Instruction *visitLShr(BinaryOperator &I);
129131
Instruction *commonShiftTransforms(BinaryOperator &I);

llvm/test/Transforms/InstCombine/amdgpu-assume-ballot-uniform.ll

Lines changed: 0 additions & 108 deletions
This file was deleted.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
2+
3+
; Test cases for optimizing AMDGPU ballot intrinsics
4+
; Focus on constant folding ballot(true) -> -1 and ballot(false) -> 0
5+
6+
define void @test_ballot_constant_true() {
7+
; CHECK-LABEL: @test_ballot_constant_true(
8+
; CHECK-NEXT: entry:
9+
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 -1, -1
10+
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
11+
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
12+
; CHECK: foo:
13+
; CHECK-NEXT: ret void
14+
; CHECK: bar:
15+
; CHECK-NEXT: ret void
16+
;
17+
entry:
18+
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 true)
19+
%all = icmp eq i64 %ballot, -1
20+
call void @llvm.assume(i1 %all)
21+
br i1 true, label %foo, label %bar
22+
23+
foo:
24+
ret void
25+
26+
bar:
27+
ret void
28+
}
29+
30+
define void @test_ballot_constant_false() {
31+
; CHECK-LABEL: @test_ballot_constant_false(
32+
; CHECK-NEXT: entry:
33+
; CHECK-NEXT: [[NONE:%.*]] = icmp ne i64 0, 0
34+
; CHECK-NEXT: call void @llvm.assume(i1 [[NONE]])
35+
; CHECK-NEXT: br i1 false, label [[FOO:%.*]], label [[BAR:%.*]]
36+
; CHECK: foo:
37+
; CHECK-NEXT: ret void
38+
; CHECK: bar:
39+
; CHECK-NEXT: ret void
40+
;
41+
entry:
42+
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 false)
43+
%none = icmp ne i64 %ballot, 0
44+
call void @llvm.assume(i1 %none)
45+
br i1 false, label %foo, label %bar
46+
47+
foo:
48+
ret void
49+
50+
bar:
51+
ret void
52+
}
53+
54+
; Test with 32-bit ballot constants
55+
define void @test_ballot_i32_constant_true() {
56+
; CHECK-LABEL: @test_ballot_i32_constant_true(
57+
; CHECK-NEXT: entry:
58+
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i32 -1, -1
59+
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
60+
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
61+
; CHECK: foo:
62+
; CHECK-NEXT: ret void
63+
; CHECK: bar:
64+
; CHECK-NEXT: ret void
65+
;
66+
entry:
67+
%ballot = call i32 @llvm.amdgcn.ballot.i32(i1 true)
68+
%all = icmp eq i32 %ballot, -1
69+
call void @llvm.assume(i1 %all)
70+
br i1 true, label %foo, label %bar
71+
72+
foo:
73+
ret void
74+
75+
bar:
76+
ret void
77+
}
78+
79+
; Negative test - variable condition should not be optimized
80+
define void @test_ballot_variable_condition(i32 %x) {
81+
; CHECK-LABEL: @test_ballot_variable_condition(
82+
; CHECK-NEXT: entry:
83+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
84+
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
85+
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
86+
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
87+
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
88+
; CHECK: foo:
89+
; CHECK-NEXT: ret void
90+
; CHECK: bar:
91+
; CHECK-NEXT: ret void
92+
;
93+
entry:
94+
%cmp = icmp eq i32 %x, 0
95+
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
96+
%all = icmp eq i64 %ballot, -1
97+
call void @llvm.assume(i1 %all)
98+
br i1 %cmp, label %foo, label %bar
99+
100+
foo:
101+
ret void
102+
103+
bar:
104+
ret void
105+
}
106+
107+
declare i64 @llvm.amdgcn.ballot.i64(i1)
108+
declare i32 @llvm.amdgcn.ballot.i32(i1)
109+
declare void @llvm.assume(i1)

0 commit comments

Comments
 (0)