Skip to content

Commit d756218

Browse files
[InstCombine] Add focused assume-based optimizations
This commit implements two targeted optimizations for assume intrinsics: 1. Basic equality optimization: assume(x == c) replaces dominated uses of x with c 2. AMDGPU ballot optimization: assume(ballot(cmp) == -1) replaces dominated uses of cmp with true, since ballot == -1 means cmp is true on all active lanes Key design principles: - No uniformity analysis concepts - uses simple mathematical facts - Dominance-based replacement for correctness - Clean pattern matching without complex framework - Addresses reviewer feedback to keep it simple and focused Examples: assume(x == 42); use = add x, 1 → use = 43 assume(ballot(cmp) == -1); br cmp → br true This enables better optimization of GPU code patterns while remaining architecture-agnostic through the mathematical properties of the operations.
1 parent 4c9c451 commit d756218

File tree

3 files changed

+79
-124
lines changed

3 files changed

+79
-124
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 74 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -3549,6 +3549,79 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35493549
}
35503550
}
35513551

3552+
// Basic assume equality optimization: assume(x == c) -> replace dominated uses of x with c
3553+
if (auto *ICmp = dyn_cast<ICmpInst>(IIOperand)) {
3554+
if (ICmp->getPredicate() == ICmpInst::ICMP_EQ) {
3555+
Value *LHS = ICmp->getOperand(0);
3556+
Value *RHS = ICmp->getOperand(1);
3557+
Value *Variable = nullptr;
3558+
Constant *ConstantVal = nullptr;
3559+
3560+
if (auto *C = dyn_cast<Constant>(RHS)) {
3561+
Variable = LHS;
3562+
ConstantVal = C;
3563+
} else if (auto *C = dyn_cast<Constant>(LHS)) {
3564+
Variable = RHS;
3565+
ConstantVal = C;
3566+
}
3567+
3568+
if (Variable && ConstantVal && Variable->hasUseList()) {
3569+
SmallVector<Use *, 8> DominatedUses;
3570+
for (Use &U : Variable->uses()) {
3571+
if (auto *UseInst = dyn_cast<Instruction>(U.getUser())) {
3572+
if (UseInst != II && UseInst != ICmp &&
3573+
isValidAssumeForContext(II, UseInst, &DT)) {
3574+
DominatedUses.push_back(&U);
3575+
}
3576+
}
3577+
}
3578+
3579+
for (Use *U : DominatedUses) {
3580+
U->set(ConstantVal);
3581+
Worklist.pushValue(U->getUser());
3582+
}
3583+
3584+
if (!DominatedUses.empty()) {
3585+
Worklist.pushValue(Variable);
3586+
}
3587+
}
3588+
}
3589+
}
3590+
3591+
// Optimize AMDGPU ballot patterns in assumes:
3592+
// assume(ballot(cmp) == -1) means cmp is true on all active lanes
3593+
// We can replace uses of cmp with true in dominated contexts
3594+
Value *BallotInst;
3595+
if (match(IIOperand, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(BallotInst), m_AllOnes()))) {
3596+
if (auto *IntrCall = dyn_cast<IntrinsicInst>(BallotInst)) {
3597+
if (IntrCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
3598+
Value *BallotArg = IntrCall->getArgOperand(0);
3599+
if (BallotArg->getType()->isIntegerTy(1) && BallotArg->hasUseList()) {
3600+
// Find dominated uses and replace with true
3601+
SmallVector<Use *, 8> DominatedUses;
3602+
for (Use &U : BallotArg->uses()) {
3603+
if (auto *UseInst = dyn_cast<Instruction>(U.getUser())) {
3604+
if (UseInst != II && UseInst != IntrCall &&
3605+
isValidAssumeForContext(II, UseInst, &DT)) {
3606+
DominatedUses.push_back(&U);
3607+
}
3608+
}
3609+
}
3610+
3611+
// Replace dominated uses with true
3612+
for (Use *U : DominatedUses) {
3613+
U->set(ConstantInt::getTrue(BallotArg->getType()));
3614+
Worklist.pushValue(U->getUser());
3615+
}
3616+
3617+
if (!DominatedUses.empty()) {
3618+
Worklist.pushValue(BallotArg);
3619+
}
3620+
}
3621+
}
3622+
}
3623+
}
3624+
35523625
// If there is a dominating assume with the same condition as this one,
35533626
// then this one is redundant, and should be removed.
35543627
KnownBits Known(1);
@@ -3562,10 +3635,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35623635
return eraseInstFromFunction(*II);
35633636
}
35643637

3565-
// Try to extract uniformity information from the assume and optimize
3566-
// dominated uses of any variables that are established as uniform.
3567-
optimizeAssumedUniformValues(cast<AssumeInst>(II));
3568-
35693638
// Update the cache of affected values for this assumption (we might be
35703639
// here because we just simplified the condition).
35713640
AC.updateAffectedValues(cast<AssumeInst>(II));
@@ -5031,116 +5100,4 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
50315100
return &Call;
50325101
}
50335102

5034-
/// Extract uniformity information from assume and optimize dominated uses.
5035-
/// This works with any assume pattern that establishes value uniformity.
5036-
void InstCombinerImpl::optimizeAssumedUniformValues(AssumeInst *Assume) {
5037-
Value *AssumedCondition = Assume->getArgOperand(0);
5038-
5039-
// Map of uniform values to their uniform constants
5040-
SmallDenseMap<Value *, Constant *> UniformValues;
5041-
5042-
// Pattern 1: assume(icmp eq (X, C)) -> X is uniform and equals C
5043-
if (auto *ICmp = dyn_cast<ICmpInst>(AssumedCondition)) {
5044-
if (ICmp->getPredicate() == ICmpInst::ICMP_EQ) {
5045-
Value *LHS = ICmp->getOperand(0);
5046-
Value *RHS = ICmp->getOperand(1);
5047-
5048-
// X == constant -> X is uniform and equals constant
5049-
if (auto *C = dyn_cast<Constant>(RHS)) {
5050-
UniformValues[LHS] = C;
5051-
} else if (auto *C = dyn_cast<Constant>(LHS)) {
5052-
UniformValues[RHS] = C;
5053-
}
5054-
5055-
// Handle intrinsic patterns in equality comparisons
5056-
// Pattern: assume(ballot(cmp) == -1) -> cmp is uniform and true
5057-
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(LHS)) {
5058-
if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
5059-
if (match(RHS, m_AllOnes())) {
5060-
Value *BallotArg = IntrinsicCall->getArgOperand(0);
5061-
if (BallotArg->getType()->isIntegerTy(1)) {
5062-
UniformValues[BallotArg] = ConstantInt::getTrue(BallotArg->getType());
5063-
5064-
// Special case: if BallotArg is an equality comparison,
5065-
// we know the operands are equal
5066-
if (auto *CmpInst = dyn_cast<ICmpInst>(BallotArg)) {
5067-
if (CmpInst->getPredicate() == ICmpInst::ICMP_EQ) {
5068-
Value *CmpLHS = CmpInst->getOperand(0);
5069-
Value *CmpRHS = CmpInst->getOperand(1);
5070-
5071-
// If one operand is constant, the other is uniform and equals that constant
5072-
if (auto *C = dyn_cast<Constant>(CmpRHS)) {
5073-
UniformValues[CmpLHS] = C;
5074-
} else if (auto *C = dyn_cast<Constant>(CmpLHS)) {
5075-
UniformValues[CmpRHS] = C;
5076-
}
5077-
// TODO: Handle case where both operands are variables
5078-
}
5079-
}
5080-
}
5081-
}
5082-
} else if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_readfirstlane) {
5083-
// assume(readfirstlane(x) == c) -> x is uniform and equals c
5084-
if (auto *C = dyn_cast<Constant>(RHS)) {
5085-
Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand(0);
5086-
UniformValues[ReadFirstLaneArg] = C;
5087-
}
5088-
}
5089-
}
5090-
5091-
// Handle the reverse case too
5092-
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(RHS)) {
5093-
if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
5094-
if (match(LHS, m_AllOnes())) {
5095-
Value *BallotArg = IntrinsicCall->getArgOperand(0);
5096-
if (BallotArg->getType()->isIntegerTy(1)) {
5097-
UniformValues[BallotArg] = ConstantInt::getTrue(BallotArg->getType());
5098-
}
5099-
}
5100-
} else if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_readfirstlane) {
5101-
if (auto *C = dyn_cast<Constant>(LHS)) {
5102-
Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand(0);
5103-
UniformValues[ReadFirstLaneArg] = C;
5104-
}
5105-
}
5106-
}
5107-
}
5108-
}
5109-
5110-
// Pattern 2: assume(X) where X is i1 -> X is uniform and equals true
5111-
if (AssumedCondition->getType()->isIntegerTy(1) && !isa<ICmpInst>(AssumedCondition)) {
5112-
UniformValues[AssumedCondition] = ConstantInt::getTrue(AssumedCondition->getType());
5113-
}
5114-
5115-
// Now optimize dominated uses of all discovered uniform values
5116-
for (auto &[UniformValue, UniformConstant] : UniformValues) {
5117-
SmallVector<Use *, 8> DominatedUses;
5118-
5119-
// Find all uses dominated by the assume
5120-
// Skip if the value doesn't have a use list (e.g., constants)
5121-
if (!UniformValue->hasUseList())
5122-
continue;
5123-
5124-
for (Use &U : UniformValue->uses()) {
5125-
Instruction *UseInst = dyn_cast<Instruction>(U.getUser());
5126-
if (!UseInst || UseInst == Assume)
5127-
continue;
5128-
5129-
// Critical: Check dominance using InstCombine's infrastructure
5130-
if (isValidAssumeForContext(Assume, UseInst, &DT)) {
5131-
DominatedUses.push_back(&U);
5132-
}
5133-
}
5134-
5135-
// Replace only dominated uses with the uniform constant
5136-
for (Use *U : DominatedUses) {
5137-
U->set(UniformConstant);
5138-
Worklist.pushValue(U->getUser());
5139-
}
5140-
5141-
// Mark for further optimization if we made changes
5142-
if (!DominatedUses.empty()) {
5143-
Worklist.pushValue(UniformValue);
5144-
}
5145-
}
5146-
}
5103+

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
230230
private:
231231
bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
232232
bool isDesirableIntType(unsigned BitWidth) const;
233-
234-
/// Optimize uses of variables that are established as uniform by assume intrinsics.
235-
void optimizeAssumedUniformValues(AssumeInst *Assume);
236233
bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
237234
bool shouldChangeType(Type *From, Type *To) const;
238235
Value *dyn_castNegVal(Value *V) const;

llvm/test/Transforms/InstCombine/assume.ll

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,12 +1056,13 @@ define i1 @neg_assume_trunc_eq_one(i8 %x) {
10561056
ret i1 %q
10571057
}
10581058

1059-
; Test AMDGPU ballot uniformity pattern optimization
1060-
; This demonstrates that assume(ballot(cmp) == -1) enables the optimization
1061-
; of cmp to true, which then optimizes the branch condition
1059+
; Test AMDGPU ballot pattern optimization
1060+
; assume(ballot(cmp) == -1) means cmp is true on all active lanes
1061+
; so dominated uses of cmp can be replaced with true
10621062
define void @assume_ballot_uniform(i32 %x) {
10631063
; CHECK-LABEL: @assume_ballot_uniform(
1064-
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 true)
1064+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
1065+
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
10651066
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
10661067
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
10671068
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]

0 commit comments

Comments
 (0)