@@ -3549,6 +3549,79 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35493549 }
35503550 }
35513551
3552+ // Basic assume equality optimization: assume(x == c) -> replace dominated uses of x with c
3553+ if (auto *ICmp = dyn_cast<ICmpInst>(IIOperand)) {
3554+ if (ICmp->getPredicate () == ICmpInst::ICMP_EQ) {
3555+ Value *LHS = ICmp->getOperand (0 );
3556+ Value *RHS = ICmp->getOperand (1 );
3557+ Value *Variable = nullptr ;
3558+ Constant *ConstantVal = nullptr ;
3559+
3560+ if (auto *C = dyn_cast<Constant>(RHS)) {
3561+ Variable = LHS;
3562+ ConstantVal = C;
3563+ } else if (auto *C = dyn_cast<Constant>(LHS)) {
3564+ Variable = RHS;
3565+ ConstantVal = C;
3566+ }
3567+
3568+ if (Variable && ConstantVal && Variable->hasUseList ()) {
3569+ SmallVector<Use *, 8 > DominatedUses;
3570+ for (Use &U : Variable->uses ()) {
3571+ if (auto *UseInst = dyn_cast<Instruction>(U.getUser ())) {
3572+ if (UseInst != II && UseInst != ICmp &&
3573+ isValidAssumeForContext (II, UseInst, &DT)) {
3574+ DominatedUses.push_back (&U);
3575+ }
3576+ }
3577+ }
3578+
3579+ for (Use *U : DominatedUses) {
3580+ U->set (ConstantVal);
3581+ Worklist.pushValue (U->getUser ());
3582+ }
3583+
3584+ if (!DominatedUses.empty ()) {
3585+ Worklist.pushValue (Variable);
3586+ }
3587+ }
3588+ }
3589+ }
3590+
3591+ // Optimize AMDGPU ballot patterns in assumes:
3592+ // assume(ballot(cmp) == -1) means cmp is true on all active lanes
3593+ // We can replace uses of cmp with true in dominated contexts
3594+ Value *BallotInst;
3595+ if (match (IIOperand, m_SpecificICmp (ICmpInst::ICMP_EQ, m_Value (BallotInst), m_AllOnes ()))) {
3596+ if (auto *IntrCall = dyn_cast<IntrinsicInst>(BallotInst)) {
3597+ if (IntrCall->getIntrinsicID () == Intrinsic::amdgcn_ballot) {
3598+ Value *BallotArg = IntrCall->getArgOperand (0 );
3599+ if (BallotArg->getType ()->isIntegerTy (1 ) && BallotArg->hasUseList ()) {
3600+ // Find dominated uses and replace with true
3601+ SmallVector<Use *, 8 > DominatedUses;
3602+ for (Use &U : BallotArg->uses ()) {
3603+ if (auto *UseInst = dyn_cast<Instruction>(U.getUser ())) {
3604+ if (UseInst != II && UseInst != IntrCall &&
3605+ isValidAssumeForContext (II, UseInst, &DT)) {
3606+ DominatedUses.push_back (&U);
3607+ }
3608+ }
3609+ }
3610+
3611+ // Replace dominated uses with true
3612+ for (Use *U : DominatedUses) {
3613+ U->set (ConstantInt::getTrue (BallotArg->getType ()));
3614+ Worklist.pushValue (U->getUser ());
3615+ }
3616+
3617+ if (!DominatedUses.empty ()) {
3618+ Worklist.pushValue (BallotArg);
3619+ }
3620+ }
3621+ }
3622+ }
3623+ }
3624+
35523625 // If there is a dominating assume with the same condition as this one,
35533626 // then this one is redundant, and should be removed.
35543627 KnownBits Known (1 );
@@ -3562,10 +3635,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
35623635 return eraseInstFromFunction (*II);
35633636 }
35643637
3565- // Try to extract uniformity information from the assume and optimize
3566- // dominated uses of any variables that are established as uniform.
3567- optimizeAssumedUniformValues (cast<AssumeInst>(II));
3568-
35693638 // Update the cache of affected values for this assumption (we might be
35703639 // here because we just simplified the condition).
35713640 AC.updateAffectedValues (cast<AssumeInst>(II));
@@ -5031,116 +5100,4 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
50315100 return &Call;
50325101}
50335102
5034- // / Extract uniformity information from assume and optimize dominated uses.
5035- // / This works with any assume pattern that establishes value uniformity.
5036- void InstCombinerImpl::optimizeAssumedUniformValues (AssumeInst *Assume) {
5037- Value *AssumedCondition = Assume->getArgOperand (0 );
5038-
5039- // Map of uniform values to their uniform constants
5040- SmallDenseMap<Value *, Constant *> UniformValues;
5041-
5042- // Pattern 1: assume(icmp eq (X, C)) -> X is uniform and equals C
5043- if (auto *ICmp = dyn_cast<ICmpInst>(AssumedCondition)) {
5044- if (ICmp->getPredicate () == ICmpInst::ICMP_EQ) {
5045- Value *LHS = ICmp->getOperand (0 );
5046- Value *RHS = ICmp->getOperand (1 );
5047-
5048- // X == constant -> X is uniform and equals constant
5049- if (auto *C = dyn_cast<Constant>(RHS)) {
5050- UniformValues[LHS] = C;
5051- } else if (auto *C = dyn_cast<Constant>(LHS)) {
5052- UniformValues[RHS] = C;
5053- }
5054-
5055- // Handle intrinsic patterns in equality comparisons
5056- // Pattern: assume(ballot(cmp) == -1) -> cmp is uniform and true
5057- if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(LHS)) {
5058- if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_ballot) {
5059- if (match (RHS, m_AllOnes ())) {
5060- Value *BallotArg = IntrinsicCall->getArgOperand (0 );
5061- if (BallotArg->getType ()->isIntegerTy (1 )) {
5062- UniformValues[BallotArg] = ConstantInt::getTrue (BallotArg->getType ());
5063-
5064- // Special case: if BallotArg is an equality comparison,
5065- // we know the operands are equal
5066- if (auto *CmpInst = dyn_cast<ICmpInst>(BallotArg)) {
5067- if (CmpInst->getPredicate () == ICmpInst::ICMP_EQ) {
5068- Value *CmpLHS = CmpInst->getOperand (0 );
5069- Value *CmpRHS = CmpInst->getOperand (1 );
5070-
5071- // If one operand is constant, the other is uniform and equals that constant
5072- if (auto *C = dyn_cast<Constant>(CmpRHS)) {
5073- UniformValues[CmpLHS] = C;
5074- } else if (auto *C = dyn_cast<Constant>(CmpLHS)) {
5075- UniformValues[CmpRHS] = C;
5076- }
5077- // TODO: Handle case where both operands are variables
5078- }
5079- }
5080- }
5081- }
5082- } else if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_readfirstlane) {
5083- // assume(readfirstlane(x) == c) -> x is uniform and equals c
5084- if (auto *C = dyn_cast<Constant>(RHS)) {
5085- Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand (0 );
5086- UniformValues[ReadFirstLaneArg] = C;
5087- }
5088- }
5089- }
5090-
5091- // Handle the reverse case too
5092- if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(RHS)) {
5093- if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_ballot) {
5094- if (match (LHS, m_AllOnes ())) {
5095- Value *BallotArg = IntrinsicCall->getArgOperand (0 );
5096- if (BallotArg->getType ()->isIntegerTy (1 )) {
5097- UniformValues[BallotArg] = ConstantInt::getTrue (BallotArg->getType ());
5098- }
5099- }
5100- } else if (IntrinsicCall->getIntrinsicID () == Intrinsic::amdgcn_readfirstlane) {
5101- if (auto *C = dyn_cast<Constant>(LHS)) {
5102- Value *ReadFirstLaneArg = IntrinsicCall->getArgOperand (0 );
5103- UniformValues[ReadFirstLaneArg] = C;
5104- }
5105- }
5106- }
5107- }
5108- }
5109-
5110- // Pattern 2: assume(X) where X is i1 -> X is uniform and equals true
5111- if (AssumedCondition->getType ()->isIntegerTy (1 ) && !isa<ICmpInst>(AssumedCondition)) {
5112- UniformValues[AssumedCondition] = ConstantInt::getTrue (AssumedCondition->getType ());
5113- }
5114-
5115- // Now optimize dominated uses of all discovered uniform values
5116- for (auto &[UniformValue, UniformConstant] : UniformValues) {
5117- SmallVector<Use *, 8 > DominatedUses;
5118-
5119- // Find all uses dominated by the assume
5120- // Skip if the value doesn't have a use list (e.g., constants)
5121- if (!UniformValue->hasUseList ())
5122- continue ;
5123-
5124- for (Use &U : UniformValue->uses ()) {
5125- Instruction *UseInst = dyn_cast<Instruction>(U.getUser ());
5126- if (!UseInst || UseInst == Assume)
5127- continue ;
5128-
5129- // Critical: Check dominance using InstCombine's infrastructure
5130- if (isValidAssumeForContext (Assume, UseInst, &DT)) {
5131- DominatedUses.push_back (&U);
5132- }
5133- }
5134-
5135- // Replace only dominated uses with the uniform constant
5136- for (Use *U : DominatedUses) {
5137- U->set (UniformConstant);
5138- Worklist.pushValue (U->getUser ());
5139- }
5140-
5141- // Mark for further optimization if we made changes
5142- if (!DominatedUses.empty ()) {
5143- Worklist.pushValue (UniformValue);
5144- }
5145- }
5146- }
5103+
0 commit comments