@@ -3540,6 +3540,79 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
3540
3540
}
3541
3541
}
3542
3542
3543
+ // Basic assume equality optimization: assume(x == c) -> replace dominated uses of x with c
3544
+ if (auto *ICmp = dyn_cast<ICmpInst>(IIOperand)) {
3545
+ if (ICmp->getPredicate () == ICmpInst::ICMP_EQ) {
3546
+ Value *LHS = ICmp->getOperand (0 );
3547
+ Value *RHS = ICmp->getOperand (1 );
3548
+ Value *Variable = nullptr ;
3549
+ Constant *ConstantVal = nullptr ;
3550
+
3551
+ if (auto *C = dyn_cast<Constant>(RHS)) {
3552
+ Variable = LHS;
3553
+ ConstantVal = C;
3554
+ } else if (auto *C = dyn_cast<Constant>(LHS)) {
3555
+ Variable = RHS;
3556
+ ConstantVal = C;
3557
+ }
3558
+
3559
+ if (Variable && ConstantVal && Variable->hasUseList ()) {
3560
+ SmallVector<Use *, 8 > DominatedUses;
3561
+ for (Use &U : Variable->uses ()) {
3562
+ if (auto *UseInst = dyn_cast<Instruction>(U.getUser ())) {
3563
+ if (UseInst != II && UseInst != ICmp &&
3564
+ isValidAssumeForContext (II, UseInst, &DT)) {
3565
+ DominatedUses.push_back (&U);
3566
+ }
3567
+ }
3568
+ }
3569
+
3570
+ for (Use *U : DominatedUses) {
3571
+ U->set (ConstantVal);
3572
+ Worklist.pushValue (U->getUser ());
3573
+ }
3574
+
3575
+ if (!DominatedUses.empty ()) {
3576
+ Worklist.pushValue (Variable);
3577
+ }
3578
+ }
3579
+ }
3580
+ }
3581
+
3582
+ // Optimize AMDGPU ballot patterns in assumes:
3583
+ // assume(ballot(cmp) == -1) means cmp is true on all active lanes
3584
+ // We can replace uses of cmp with true in dominated contexts
3585
+ Value *BallotInst;
3586
+ if (match (IIOperand, m_SpecificICmp (ICmpInst::ICMP_EQ, m_Value (BallotInst), m_AllOnes ()))) {
3587
+ if (auto *IntrCall = dyn_cast<IntrinsicInst>(BallotInst)) {
3588
+ if (IntrCall->getIntrinsicID () == Intrinsic::amdgcn_ballot) {
3589
+ Value *BallotArg = IntrCall->getArgOperand (0 );
3590
+ if (BallotArg->getType ()->isIntegerTy (1 ) && BallotArg->hasUseList ()) {
3591
+ // Find dominated uses and replace with true
3592
+ SmallVector<Use *, 8 > DominatedUses;
3593
+ for (Use &U : BallotArg->uses ()) {
3594
+ if (auto *UseInst = dyn_cast<Instruction>(U.getUser ())) {
3595
+ if (UseInst != II && UseInst != IntrCall &&
3596
+ isValidAssumeForContext (II, UseInst, &DT)) {
3597
+ DominatedUses.push_back (&U);
3598
+ }
3599
+ }
3600
+ }
3601
+
3602
+ // Replace dominated uses with true
3603
+ for (Use *U : DominatedUses) {
3604
+ U->set (ConstantInt::getTrue (BallotArg->getType ()));
3605
+ Worklist.pushValue (U->getUser ());
3606
+ }
3607
+
3608
+ if (!DominatedUses.empty ()) {
3609
+ Worklist.pushValue (BallotArg);
3610
+ }
3611
+ }
3612
+ }
3613
+ }
3614
+ }
3615
+
3543
3616
// If there is a dominating assume with the same condition as this one,
3544
3617
// then this one is redundant, and should be removed.
3545
3618
KnownBits Known (1 );
@@ -3553,10 +3626,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
3553
3626
return eraseInstFromFunction (*II);
3554
3627
}
3555
3628
3556
- // Try to extract uniformity information from the assume and optimize
3557
- // dominated uses of any variables that are established as uniform.
3558
- optimizeAssumedUniformValues (cast<AssumeInst>(II));
3559
-
3560
3629
// Update the cache of affected values for this assumption (we might be
3561
3630
// here because we just simplified the condition).
3562
3631
AC.updateAffectedValues (cast<AssumeInst>(II));
@@ -5011,116 +5080,4 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
5011
5080
return &Call;
5012
5081
}
5013
5082
5014
- // / Extract uniformity information from assume and optimize dominated uses.
5015
- // / This works with any assume pattern that establishes value uniformity.
5016
- void InstCombinerImpl::optimizeAssumedUniformValues (AssumeInst *Assume) {
5017
- Value *AssumedCondition = Assume->getArgOperand (0 );
5018
-
5019
- // Map of uniform values to their uniform constants
5020
- SmallDenseMap<Value *, Constant *> UniformValues;
5021
-
5022
- // Pattern 1: assume(icmp eq (X, C)) -> X is uniform and equals C
5023
- if (auto *ICmp = dyn_cast<ICmpInst>(AssumedCondition)) {
5024
- if (ICmp->getPredicate () == ICmpInst::ICMP_EQ) {
5025
- Value *LHS = ICmp->getOperand (0 );
5026
- Value *RHS = ICmp->getOperand (1 );
5027
-
5028
- // X == constant -> X is uniform and equals constant
5029
- if (auto *C = dyn_cast<Constant>(RHS)) {
5030
- UniformValues[LHS] = C;
5031
- } else if (auto *C = dyn_cast<Constant>(LHS)) {
5032
- UniformValues[RHS] = C;
5033
- }
5034
-
5035
- // Handle intrinsic patterns in equality comparisons
5036
- // Pattern: assume(ballot(cmp) == -1) -> cmp is uniform and true
5037
- if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(LHS)) {
5038
- if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_ballot) {
5039
- if (match (RHS, m_AllOnes ())) {
5040
- Value *BallotArg = IntrinsicCall->getArgOperand (0 );
5041
- if (BallotArg->getType ()->isIntegerTy (1 )) {
5042
- UniformValues[BallotArg] = ConstantInt::getTrue (BallotArg->getType ());
5043
-
5044
- // Special case: if BallotArg is an equality comparison,
5045
- // we know the operands are equal
5046
- if (auto *CmpInst = dyn_cast<ICmpInst>(BallotArg)) {
5047
- if (CmpInst->getPredicate () == ICmpInst::ICMP_EQ) {
5048
- Value *CmpLHS = CmpInst->getOperand (0 );
5049
- Value *CmpRHS = CmpInst->getOperand (1 );
5050
-
5051
- // If one operand is constant, the other is uniform and equals that constant
5052
- if (auto *C = dyn_cast<Constant>(CmpRHS)) {
5053
- UniformValues[CmpLHS] = C;
5054
- } else if (auto *C = dyn_cast<Constant>(CmpLHS)) {
5055
- UniformValues[CmpRHS] = C;
5056
- }
5057
- // TODO: Handle case where both operands are variables
5058
- }
5059
- }
5060
- }
5061
- }
5062
- } else if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_readfirstlane) {
5063
- // assume(readfirstlane(x) == c) -> x is uniform and equals c
5064
- if (auto *C = dyn_cast<Constant>(RHS)) {
5065
- Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand (0 );
5066
- UniformValues[ReadFirstLaneArg] = C;
5067
- }
5068
- }
5069
- }
5070
-
5071
- // Handle the reverse case too
5072
- if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(RHS)) {
5073
- if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_ballot) {
5074
- if (match (LHS, m_AllOnes ())) {
5075
- Value *BallotArg = IntrinsicCall->getArgOperand (0 );
5076
- if (BallotArg->getType ()->isIntegerTy (1 )) {
5077
- UniformValues[BallotArg] = ConstantInt::getTrue (BallotArg->getType ());
5078
- }
5079
- }
5080
- } else if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_readfirstlane) {
5081
- if (auto *C = dyn_cast<Constant>(LHS)) {
5082
- Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand (0 );
5083
- UniformValues[ReadFirstLaneArg] = C;
5084
- }
5085
- }
5086
- }
5087
- }
5088
- }
5089
-
5090
- // Pattern 2: assume(X) where X is i1 -> X is uniform and equals true
5091
- if (AssumedCondition->getType ()->isIntegerTy (1 ) && !isa<ICmpInst>(AssumedCondition)) {
5092
- UniformValues[AssumedCondition] = ConstantInt::getTrue (AssumedCondition->getType ());
5093
- }
5094
-
5095
- // Now optimize dominated uses of all discovered uniform values
5096
- for (auto &[UniformValue, UniformConstant] : UniformValues) {
5097
- SmallVector<Use *, 8 > DominatedUses;
5098
-
5099
- // Find all uses dominated by the assume
5100
- // Skip if the value doesn't have a use list (e.g., constants)
5101
- if (!UniformValue->hasUseList ())
5102
- continue ;
5103
-
5104
- for (Use &U : UniformValue->uses ()) {
5105
- Instruction *UseInst = dyn_cast<Instruction>(U.getUser ());
5106
- if (!UseInst || UseInst == Assume)
5107
- continue ;
5108
-
5109
- // Critical: Check dominance using InstCombine's infrastructure
5110
- if (isValidAssumeForContext (Assume, UseInst, &DT)) {
5111
- DominatedUses.push_back (&U);
5112
- }
5113
- }
5114
-
5115
- // Replace only dominated uses with the uniform constant
5116
- for (Use *U : DominatedUses) {
5117
- U->set (UniformConstant);
5118
- Worklist.pushValue (U->getUser ());
5119
- }
5120
-
5121
- // Mark for further optimization if we made changes
5122
- if (!DominatedUses.empty ()) {
5123
- Worklist.pushValue (UniformValue);
5124
- }
5125
- }
5126
- }
5083
+
0 commit comments