Skip to content

Commit c577951

Browse files
[InstCombine] Add focused assume-based optimizations
This commit implements two targeted optimizations for assume intrinsics:

1. Basic equality optimization: assume(x == c) replaces dominated uses of x with c.
2. AMDGPU ballot optimization: assume(ballot(cmp) == -1) replaces dominated uses of cmp with true, since ballot == -1 means cmp is true on all active lanes.

Key design principles:
- No uniformity analysis concepts - uses simple mathematical facts
- Dominance-based replacement for correctness
- Clean pattern matching without complex framework
- Addresses reviewer feedback to keep it simple and focused

Examples:
  assume(x == 42); use = add x, 1 → use = 43
  assume(ballot(cmp) == -1); br cmp → br true

This enables better optimization of GPU code patterns while remaining architecture-agnostic through the mathematical properties of the operations.
1 parent b692ae9 commit c577951

File tree

3 files changed

+79
-124
lines changed

3 files changed

+79
-124
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 74 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -3540,6 +3540,79 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35403540
}
35413541
}
35423542

3543+
// Basic assume equality optimization: assume(x == c) -> replace dominated uses of x with c
3544+
if (auto *ICmp = dyn_cast<ICmpInst>(IIOperand)) {
3545+
if (ICmp->getPredicate() == ICmpInst::ICMP_EQ) {
3546+
Value *LHS = ICmp->getOperand(0);
3547+
Value *RHS = ICmp->getOperand(1);
3548+
Value *Variable = nullptr;
3549+
Constant *ConstantVal = nullptr;
3550+
3551+
if (auto *C = dyn_cast<Constant>(RHS)) {
3552+
Variable = LHS;
3553+
ConstantVal = C;
3554+
} else if (auto *C = dyn_cast<Constant>(LHS)) {
3555+
Variable = RHS;
3556+
ConstantVal = C;
3557+
}
3558+
3559+
if (Variable && ConstantVal && Variable->hasUseList()) {
3560+
SmallVector<Use *, 8> DominatedUses;
3561+
for (Use &U : Variable->uses()) {
3562+
if (auto *UseInst = dyn_cast<Instruction>(U.getUser())) {
3563+
if (UseInst != II && UseInst != ICmp &&
3564+
isValidAssumeForContext(II, UseInst, &DT)) {
3565+
DominatedUses.push_back(&U);
3566+
}
3567+
}
3568+
}
3569+
3570+
for (Use *U : DominatedUses) {
3571+
U->set(ConstantVal);
3572+
Worklist.pushValue(U->getUser());
3573+
}
3574+
3575+
if (!DominatedUses.empty()) {
3576+
Worklist.pushValue(Variable);
3577+
}
3578+
}
3579+
}
3580+
}
3581+
3582+
// Optimize AMDGPU ballot patterns in assumes:
3583+
// assume(ballot(cmp) == -1) means cmp is true on all active lanes
3584+
// We can replace uses of cmp with true in dominated contexts
3585+
Value *BallotInst;
3586+
if (match(IIOperand, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(BallotInst), m_AllOnes()))) {
3587+
if (auto *IntrCall = dyn_cast<IntrinsicInst>(BallotInst)) {
3588+
if (IntrCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
3589+
Value *BallotArg = IntrCall->getArgOperand(0);
3590+
if (BallotArg->getType()->isIntegerTy(1) && BallotArg->hasUseList()) {
3591+
// Find dominated uses and replace with true
3592+
SmallVector<Use *, 8> DominatedUses;
3593+
for (Use &U : BallotArg->uses()) {
3594+
if (auto *UseInst = dyn_cast<Instruction>(U.getUser())) {
3595+
if (UseInst != II && UseInst != IntrCall &&
3596+
isValidAssumeForContext(II, UseInst, &DT)) {
3597+
DominatedUses.push_back(&U);
3598+
}
3599+
}
3600+
}
3601+
3602+
// Replace dominated uses with true
3603+
for (Use *U : DominatedUses) {
3604+
U->set(ConstantInt::getTrue(BallotArg->getType()));
3605+
Worklist.pushValue(U->getUser());
3606+
}
3607+
3608+
if (!DominatedUses.empty()) {
3609+
Worklist.pushValue(BallotArg);
3610+
}
3611+
}
3612+
}
3613+
}
3614+
}
3615+
35433616
// If there is a dominating assume with the same condition as this one,
35443617
// then this one is redundant, and should be removed.
35453618
KnownBits Known(1);
@@ -3553,10 +3626,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35533626
return eraseInstFromFunction(*II);
35543627
}
35553628

3556-
// Try to extract uniformity information from the assume and optimize
3557-
// dominated uses of any variables that are established as uniform.
3558-
optimizeAssumedUniformValues(cast<AssumeInst>(II));
3559-
35603629
// Update the cache of affected values for this assumption (we might be
35613630
// here because we just simplified the condition).
35623631
AC.updateAffectedValues(cast<AssumeInst>(II));
@@ -5011,116 +5080,4 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
50115080
return &Call;
50125081
}
50135082

5014-
/// Extract uniformity information from assume and optimize dominated uses.
5015-
/// This works with any assume pattern that establishes value uniformity.
5016-
void InstCombinerImpl::optimizeAssumedUniformValues(AssumeInst *Assume) {
5017-
Value *AssumedCondition = Assume->getArgOperand(0);
5018-
5019-
// Map of uniform values to their uniform constants
5020-
SmallDenseMap<Value *, Constant *> UniformValues;
5021-
5022-
// Pattern 1: assume(icmp eq (X, C)) -> X is uniform and equals C
5023-
if (auto *ICmp = dyn_cast<ICmpInst>(AssumedCondition)) {
5024-
if (ICmp->getPredicate() == ICmpInst::ICMP_EQ) {
5025-
Value *LHS = ICmp->getOperand(0);
5026-
Value *RHS = ICmp->getOperand(1);
5027-
5028-
// X == constant -> X is uniform and equals constant
5029-
if (auto *C = dyn_cast<Constant>(RHS)) {
5030-
UniformValues[LHS] = C;
5031-
} else if (auto *C = dyn_cast<Constant>(LHS)) {
5032-
UniformValues[RHS] = C;
5033-
}
5034-
5035-
// Handle intrinsic patterns in equality comparisons
5036-
// Pattern: assume(ballot(cmp) == -1) -> cmp is uniform and true
5037-
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(LHS)) {
5038-
if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
5039-
if (match(RHS, m_AllOnes())) {
5040-
Value *BallotArg = IntrinsicCall->getArgOperand(0);
5041-
if (BallotArg->getType()->isIntegerTy(1)) {
5042-
UniformValues[BallotArg] = ConstantInt::getTrue(BallotArg->getType());
5043-
5044-
// Special case: if BallotArg is an equality comparison,
5045-
// we know the operands are equal
5046-
if (auto *CmpInst = dyn_cast<ICmpInst>(BallotArg)) {
5047-
if (CmpInst->getPredicate() == ICmpInst::ICMP_EQ) {
5048-
Value *CmpLHS = CmpInst->getOperand(0);
5049-
Value *CmpRHS = CmpInst->getOperand(1);
5050-
5051-
// If one operand is constant, the other is uniform and equals that constant
5052-
if (auto *C = dyn_cast<Constant>(CmpRHS)) {
5053-
UniformValues[CmpLHS] = C;
5054-
} else if (auto *C = dyn_cast<Constant>(CmpLHS)) {
5055-
UniformValues[CmpRHS] = C;
5056-
}
5057-
// TODO: Handle case where both operands are variables
5058-
}
5059-
}
5060-
}
5061-
}
5062-
} else if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_readfirstlane) {
5063-
// assume(readfirstlane(x) == c) -> x is uniform and equals c
5064-
if (auto *C = dyn_cast<Constant>(RHS)) {
5065-
Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand(0);
5066-
UniformValues[ReadFirstLaneArg] = C;
5067-
}
5068-
}
5069-
}
5070-
5071-
// Handle the reverse case too
5072-
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(RHS)) {
5073-
if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
5074-
if (match(LHS, m_AllOnes())) {
5075-
Value *BallotArg = IntrinsicCall->getArgOperand(0);
5076-
if (BallotArg->getType()->isIntegerTy(1)) {
5077-
UniformValues[BallotArg] = ConstantInt::getTrue(BallotArg->getType());
5078-
}
5079-
}
5080-
} else if (IntrinsicCall->getIntrinsicID() == Intrinsic::amdgcn_readfirstlane) {
5081-
if (auto *C = dyn_cast<Constant>(LHS)) {
5082-
Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand(0);
5083-
UniformValues[ReadFirstLaneArg] = C;
5084-
}
5085-
}
5086-
}
5087-
}
5088-
}
5089-
5090-
// Pattern 2: assume(X) where X is i1 -> X is uniform and equals true
5091-
if (AssumedCondition->getType()->isIntegerTy(1) && !isa<ICmpInst>(AssumedCondition)) {
5092-
UniformValues[AssumedCondition] = ConstantInt::getTrue(AssumedCondition->getType());
5093-
}
5094-
5095-
// Now optimize dominated uses of all discovered uniform values
5096-
for (auto &[UniformValue, UniformConstant] : UniformValues) {
5097-
SmallVector<Use *, 8> DominatedUses;
5098-
5099-
// Find all uses dominated by the assume
5100-
// Skip if the value doesn't have a use list (e.g., constants)
5101-
if (!UniformValue->hasUseList())
5102-
continue;
5103-
5104-
for (Use &U : UniformValue->uses()) {
5105-
Instruction *UseInst = dyn_cast<Instruction>(U.getUser());
5106-
if (!UseInst || UseInst == Assume)
5107-
continue;
5108-
5109-
// Critical: Check dominance using InstCombine's infrastructure
5110-
if (isValidAssumeForContext(Assume, UseInst, &DT)) {
5111-
DominatedUses.push_back(&U);
5112-
}
5113-
}
5114-
5115-
// Replace only dominated uses with the uniform constant
5116-
for (Use *U : DominatedUses) {
5117-
U->set(UniformConstant);
5118-
Worklist.pushValue(U->getUser());
5119-
}
5120-
5121-
// Mark for further optimization if we made changes
5122-
if (!DominatedUses.empty()) {
5123-
Worklist.pushValue(UniformValue);
5124-
}
5125-
}
5126-
}
5083+

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
229229
private:
230230
bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
231231
bool isDesirableIntType(unsigned BitWidth) const;
232-
233-
/// Optimize uses of variables that are established as uniform by assume intrinsics.
234-
void optimizeAssumedUniformValues(AssumeInst *Assume);
235232
bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
236233
bool shouldChangeType(Type *From, Type *To) const;
237234
Value *dyn_castNegVal(Value *V) const;

llvm/test/Transforms/InstCombine/assume.ll

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,12 +1034,13 @@ define i1 @neg_assume_trunc_eq_one(i8 %x) {
10341034
ret i1 %q
10351035
}
10361036

1037-
; Test AMDGPU ballot uniformity pattern optimization
1038-
; This demonstrates that assume(ballot(cmp) == -1) enables the optimization
1039-
; of cmp to true, which then optimizes the branch condition
1037+
; Test AMDGPU ballot pattern optimization
1038+
; assume(ballot(cmp) == -1) means cmp is true on all active lanes
1039+
; so dominated uses of cmp can be replaced with true
10401040
define void @assume_ballot_uniform(i32 %x) {
10411041
; CHECK-LABEL: @assume_ballot_uniform(
1042-
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 true)
1042+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
1043+
; CHECK-NEXT: [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
10431044
; CHECK-NEXT: [[ALL:%.*]] = icmp eq i64 [[BALLOT]], -1
10441045
; CHECK-NEXT: call void @llvm.assume(i1 [[ALL]])
10451046
; CHECK-NEXT: br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]

0 commit comments

Comments
 (0)