Skip to content

Commit 2947714

Browse files
[InstCombine] Implement generic assume-based uniformity optimization
Implement a comprehensive generic optimization for assume intrinsics that extracts uniformity information and optimizes dominated uses. The optimization recognizes multiple patterns that establish value uniformity and replaces dominated uses with uniform constants. Addresses uniformity analysis optimization opportunities identified in AMDGPU ballot/readfirstlane + assume patterns for improved code generation through constant propagation.
1 parent 654e365 commit 2947714

File tree

6 files changed

+195
-113
lines changed

6 files changed

+195
-113
lines changed

.github/copilot-instructions.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
When performing a code review, pay close attention to code modifying a function's
2+
control flow. Could the change result in the corruption of performance profile
3+
data? Could the change result in invalid debug information, in particular for
4+
branches and calls?

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,12 +1322,7 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13221322
if (isa<PoisonValue>(Arg))
13231323
return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
13241324

1325-
if (auto *Src = dyn_cast<ConstantInt>(Arg)) {
1326-
if (Src->isZero()) {
1327-
// amdgcn.ballot(i1 0) is zero.
1328-
return IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1329-
}
1330-
}
1325+
// For Wave32 targets, convert i64 ballot to i32 ballot + zext
13311326
if (ST->isWave32() && II.getType()->getIntegerBitWidth() == 64) {
13321327
// %b64 = call i64 ballot.i64(...)
13331328
// =>
@@ -1341,6 +1336,15 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13411336
Call->takeName(&II);
13421337
return IC.replaceInstUsesWith(II, Call);
13431338
}
1339+
1340+
if (auto *Src = dyn_cast<ConstantInt>(Arg)) {
1341+
if (Src->isZero()) {
1342+
// amdgcn.ballot(i1 0) is zero.
1343+
return IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1344+
}
1345+
// Note: ballot(true) is NOT constant folded because the result depends
1346+
// on the active lanes in the wavefront, not just the condition value.
1347+
}
13441348
break;
13451349
}
13461350
case Intrinsic::amdgcn_wavefrontsize: {

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@ using namespace PatternMatch;
8585

8686
STATISTIC(NumSimplified, "Number of library calls simplified");
8787

88-
89-
9088
static cl::opt<unsigned> GuardWideningWindow(
9189
"instcombine-guard-widening-window",
9290
cl::init(3),
@@ -2998,20 +2996,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
29982996
}
29992997
break;
30002998
}
3001-
case Intrinsic::amdgcn_ballot: {
3002-
// Optimize ballot intrinsics when the condition is known to be uniform
3003-
Value *Condition = II->getArgOperand(0);
3004-
3005-
// If the condition is a constant, we can evaluate the ballot directly
3006-
if (auto *ConstCond = dyn_cast<ConstantInt>(Condition)) {
3007-
// ballot(true) -> -1 (all lanes active)
3008-
// ballot(false) -> 0 (no lanes active)
3009-
uint64_t Result = ConstCond->isOne() ? ~0ULL : 0ULL;
3010-
return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result));
3011-
}
3012-
3013-
break;
3014-
}
30152999
case Intrinsic::ldexp: {
30163000
// ldexp(ldexp(x, a), b) -> ldexp(x, a + b)
30173001
//
@@ -3565,8 +3549,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35653549
}
35663550
}
35673551

3568-
3569-
35703552
// If there is a dominating assume with the same condition as this one,
35713553
// then this one is redundant, and should be removed.
35723554
KnownBits Known(1);
@@ -3580,7 +3562,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35803562
return eraseInstFromFunction(*II);
35813563
}
35823564

3583-
3565+
// Try to extract uniformity information from the assume and optimize
3566+
// dominated uses of any variables that are established as uniform.
3567+
optimizeAssumedUniformValues(cast<AssumeInst>(II));
35843568

35853569
// Update the cache of affected values for this assumption (we might be
35863570
// here because we just simplified the condition).
@@ -5046,3 +5030,117 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
50465030
Call.setCalledFunction(FTy, NestF);
50475031
return &Call;
50485032
}
5033+
5034+
/// Extract uniformity information from assume and optimize dominated uses.
5035+
/// This works with any assume pattern that establishes value uniformity.
5036+
void InstCombinerImpl::optimizeAssumedUniformValues(AssumeInst *Assume) {
5037+
Value *AssumedCondition = Assume->getArgOperand(0);
5038+
5039+
// Map of uniform values to their uniform constants
5040+
SmallDenseMap<Value *, Constant *> UniformValues;
5041+
5042+
// Pattern 1: assume(icmp eq (X, C)) -> X is uniform and equals C
5043+
if (auto *ICmp = dyn_cast<ICmpInst>(AssumedCondition)) {
5044+
if (ICmp->getPredicate() == ICmpInst::ICMP_EQ) {
5045+
Value *LHS = ICmp->getOperand(0);
5046+
Value *RHS = ICmp->getOperand(1);
5047+
5048+
// X == constant -> X is uniform and equals constant
5049+
if (auto *C = dyn_cast<Constant>(RHS)) {
5050+
UniformValues[LHS] = C;
5051+
} else if (auto *C = dyn_cast<Constant>(LHS)) {
5052+
UniformValues[RHS] = C;
5053+
}
5054+
5055+
// Handle intrinsic patterns in equality comparisons
5056+
// Pattern: assume(ballot(cmp) == -1) -> cmp is uniform and true
5057+
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(LHS)) {
5058+
if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
5059+
if (match(RHS, m_AllOnes())) {
5060+
Value *BallotArg = IntrinsicCall->getArgOperand(0);
5061+
if (BallotArg->getType()->isIntegerTy(1)) {
5062+
UniformValues[BallotArg] = ConstantInt::getTrue(BallotArg->getType());
5063+
5064+
// Special case: if BallotArg is an equality comparison,
5065+
// we know the operands are equal
5066+
if (auto *CmpInst = dyn_cast<ICmpInst>(BallotArg)) {
5067+
if (CmpInst->getPredicate() == ICmpInst::ICMP_EQ) {
5068+
Value *CmpLHS = CmpInst->getOperand(0);
5069+
Value *CmpRHS = CmpInst->getOperand(1);
5070+
5071+
// If one operand is constant, the other is uniform and equals that constant
5072+
if (auto *C = dyn_cast<Constant>(CmpRHS)) {
5073+
UniformValues[CmpLHS] = C;
5074+
} else if (auto *C = dyn_cast<Constant>(CmpLHS)) {
5075+
UniformValues[CmpRHS] = C;
5076+
}
5077+
// TODO: Handle case where both operands are variables
5078+
}
5079+
}
5080+
}
5081+
}
5082+
} else if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_readfirstlane) {
5083+
// assume(readfirstlane(x) == c) -> x is uniform and equals c
5084+
if (auto *C = dyn_cast<Constant>(RHS)) {
5085+
Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand(0);
5086+
UniformValues[ReadFirstLaneArg] = C;
5087+
}
5088+
}
5089+
}
5090+
5091+
// Handle the reverse case too
5092+
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(RHS)) {
5093+
if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
5094+
if (match(LHS, m_AllOnes())) {
5095+
Value *BallotArg = IntrinsicCall->getArgOperand(0);
5096+
if (BallotArg->getType()->isIntegerTy(1)) {
5097+
UniformValues[BallotArg] = ConstantInt::getTrue(BallotArg->getType());
5098+
}
5099+
}
5100+
} else if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_readfirstlane) {
5101+
if (auto *C = dyn_cast<Constant>(LHS)) {
5102+
Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand(0);
5103+
UniformValues[ReadFirstLaneArg] = C;
5104+
}
5105+
}
5106+
}
5107+
}
5108+
}
5109+
5110+
// Pattern 2: assume(X) where X is i1 -> X is uniform and equals true
5111+
if (AssumedCondition->getType()->isIntegerTy(1) && !isa<ICmpInst>(AssumedCondition)) {
5112+
UniformValues[AssumedCondition] = ConstantInt::getTrue(AssumedCondition->getType());
5113+
}
5114+
5115+
// Now optimize dominated uses of all discovered uniform values
5116+
for (auto &[UniformValue, UniformConstant] : UniformValues) {
5117+
SmallVector<Use *, 8> DominatedUses;
5118+
5119+
// Find all uses dominated by the assume
5120+
// Skip if the value doesn't have a use list (e.g., constants)
5121+
if (!UniformValue->hasUseList())
5122+
continue;
5123+
5124+
for (Use &U : UniformValue->uses()) {
5125+
Instruction *UseInst = dyn_cast<Instruction>(U.getUser());
5126+
if (!UseInst || UseInst == Assume)
5127+
continue;
5128+
5129+
// Critical: Check dominance using InstCombine's infrastructure
5130+
if (isValidAssumeForContext(Assume, UseInst, &DT)) {
5131+
DominatedUses.push_back(&U);
5132+
}
5133+
}
5134+
5135+
// Replace only dominated uses with the uniform constant
5136+
for (Use *U : DominatedUses) {
5137+
U->set(UniformConstant);
5138+
Worklist.pushValue(U->getUser());
5139+
}
5140+
5141+
// Mark for further optimization if we made changes
5142+
if (!DominatedUses.empty()) {
5143+
Worklist.pushValue(UniformValue);
5144+
}
5145+
}
5146+
}

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
124124
BinaryOperator &I);
125125
Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
126126
BinaryOperator &OldAShr);
127-
128-
129127
Instruction *visitAShr(BinaryOperator &I);
130128
Instruction *visitLShr(BinaryOperator &I);
131129
Instruction *commonShiftTransforms(BinaryOperator &I);
@@ -232,6 +230,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
232230
private:
233231
bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
234232
bool isDesirableIntType(unsigned BitWidth) const;
233+
234+
/// Optimize uses of variables that are established as uniform by assume intrinsics.
235+
void optimizeAssumedUniformValues(AssumeInst *Assume);
235236
bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
236237
bool shouldChangeType(Type *From, Type *To) const;
237238
Value *dyn_castNegVal(Value *V) const;
Lines changed: 32 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,56 @@
1-
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
1+
; RUN: opt < %s -mtriple=amdgcn-amd-amdhsa -passes=instcombine -S | FileCheck %s
22

33
; Test cases for optimizing AMDGPU ballot intrinsics
4-
; Focus on constant folding ballot(true) -> -1 and ballot(false) -> 0
4+
; Focus on constant folding ballot(false) -> 0 and poison handling
55

6-
define void @test_ballot_constant_true() {
7-
; CHECK-LABEL: @test_ballot_constant_true(
8-
; CHECK-NEXT: entry:
9-
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 -1, -1
10-
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
11-
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
12-
; CHECK: foo:
13-
; CHECK-NEXT: ret void
14-
; CHECK: bar:
15-
; CHECK-NEXT: ret void
6+
; Test ballot with constant false condition gets folded
7+
define i32 @test_ballot_constant_false() {
8+
; CHECK-LABEL: @test_ballot_constant_false(
9+
; CHECK-NEXT: ret i32 0
1610
;
17-
entry:
18-
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 true)
19-
%all = icmp eq i64 %ballot, -1
20-
call void @llvm.assume(i1 %all)
21-
br i1 true, label %foo, label %bar
22-
23-
foo:
24-
ret void
25-
26-
bar:
27-
ret void
11+
%ballot = call i32 @llvm.amdgcn.ballot.i32(i1 false)
12+
ret i32 %ballot
2813
}
2914

30-
define void @test_ballot_constant_false() {
31-
; CHECK-LABEL: @test_ballot_constant_false(
32-
; CHECK-NEXT: entry:
33-
; CHECK-NEXT: [[NONE:%.*]] = icmp ne i64 0, 0
34-
; CHECK-NEXT: call void @llvm.assume(i1 [[NONE]])
35-
; CHECK-NEXT: br i1 false, label [[FOO:%.*]], label [[BAR:%.*]]
36-
; CHECK: foo:
37-
; CHECK-NEXT: ret void
38-
; CHECK: bar:
39-
; CHECK-NEXT: ret void
15+
; Test ballot.i64 with constant false condition gets folded
16+
define i64 @test_ballot_i64_constant_false() {
17+
; CHECK-LABEL: @test_ballot_i64_constant_false(
18+
; CHECK-NEXT: ret i64 0
4019
;
41-
entry:
4220
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 false)
43-
%none = icmp ne i64 %ballot, 0
44-
call void @llvm.assume(i1 %none)
45-
br i1 false, label %foo, label %bar
46-
47-
foo:
48-
ret void
49-
50-
bar:
51-
ret void
21+
ret i64 %ballot
5222
}
5323

54-
; Test with 32-bit ballot constants
55-
define void @test_ballot_i32_constant_true() {
56-
; CHECK-LABEL: @test_ballot_i32_constant_true(
57-
; CHECK-NEXT: entry:
58-
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i32 -1, -1
59-
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
60-
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
61-
; CHECK: foo:
62-
; CHECK-NEXT: ret void
63-
; CHECK: bar:
64-
; CHECK-NEXT: ret void
24+
; Test ballot with poison condition gets folded to poison
25+
define i64 @test_ballot_poison() {
26+
; CHECK-LABEL: @test_ballot_poison(
27+
; CHECK-NEXT: ret i64 poison
6528
;
66-
entry:
67-
%ballot = call i32 @llvm.amdgcn.ballot.i32(i1 true)
68-
%all = icmp eq i32 %ballot, -1
69-
call void @llvm.assume(i1 %all)
70-
br i1 true, label %foo, label %bar
71-
72-
foo:
73-
ret void
29+
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 poison)
30+
ret i64 %ballot
31+
}
7432

75-
bar:
76-
ret void
33+
; Test that ballot(true) is NOT constant folded (depends on active lanes)
34+
define i64 @test_ballot_constant_true() {
35+
; CHECK-LABEL: @test_ballot_constant_true(
36+
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 true)
37+
; CHECK-NEXT: ret i64 [[BALLOT]]
38+
;
39+
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 true)
40+
ret i64 %ballot
7741
}
7842

79-
; Negative test - variable condition should not be optimized
80-
define void @test_ballot_variable_condition(i32 %x) {
43+
; Test that ballot with variable condition is not optimized
44+
define i64 @test_ballot_variable_condition(i32 %x) {
8145
; CHECK-LABEL: @test_ballot_variable_condition(
82-
; CHECK-NEXT: entry:
8346
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
8447
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
85-
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
86-
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
87-
; CHECK-NEXT: br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
88-
; CHECK: foo:
89-
; CHECK-NEXT: ret void
90-
; CHECK: bar:
91-
; CHECK-NEXT: ret void
48+
; CHECK-NEXT: ret i64 [[BALLOT]]
9249
;
93-
entry:
9450
%cmp = icmp eq i32 %x, 0
9551
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
96-
%all = icmp eq i64 %ballot, -1
97-
call void @llvm.assume(i1 %all)
98-
br i1 %cmp, label %foo, label %bar
99-
100-
foo:
101-
ret void
102-
103-
bar:
104-
ret void
52+
ret i64 %ballot
10553
}
10654

10755
declare i64 @llvm.amdgcn.ballot.i64(i1)
10856
declare i32 @llvm.amdgcn.ballot.i32(i1)
109-
declare void @llvm.assume(i1)

0 commit comments

Comments
 (0)