@@ -2540,18 +2540,60 @@ bool GVNPass::propagateEquality(
25402540 }
25412541 }
25422542
2543- // If "ballot(cond) == -1" or "ballot(cond) == exec_mask" then cond is true
2544- // on all active lanes, so cond can be replaced with true.
2545- if (IntrinsicInst *IntrCall = dyn_cast<IntrinsicInst>(LHS)) {
2546- if (IntrCall->getIntrinsicID () ==
2547- Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2548- Value *BallotArg = IntrCall->getArgOperand (0 );
2549- if (BallotArg->getType ()->isIntegerTy (1 ) &&
2550- (match (RHS, m_AllOnes ()) || !isa<Constant>(RHS))) {
2543+ // Helper function to check if a value represents the current exec mask.
2544+ auto IsExecMask = [](Value *V) -> bool {
2545+ // Pattern 1: ballot(true)
2546+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(V)) {
2547+ if (II->getIntrinsicID () == Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2548+ // Check if argument is constant true
2549+ if (match (II->getArgOperand (0 ), m_One ())) {
2550+ return true ;
2551+ }
2552+ }
2553+ }
2554+
2555+ return false ;
2556+ };
2557+
2558+ // Check if either of the operands is a ballot intrinsic.
2559+ IntrinsicInst *BallotCall = nullptr ;
2560+ Value *CompareValue = nullptr ;
2561+
2562+ // Check both LHS and RHS for ballot intrinsic and its value since GVN may
2563+ // swap the operands.
2564+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHS)) {
2565+ if (II->getIntrinsicID () == Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2566+ BallotCall = II;
2567+ CompareValue = RHS;
2568+ }
2569+ }
2570+ if (!BallotCall && isa<IntrinsicInst>(RHS)) {
2571+ IntrinsicInst *II = cast<IntrinsicInst>(RHS);
2572+ if (II->getIntrinsicID () == Intrinsic::AMDGCNIntrinsics::amdgcn_ballot) {
2573+ BallotCall = II;
2574+ CompareValue = LHS;
2575+ }
2576+ }
2577+
2578+ // If a ballot intrinsic is found, calculate the truth value of the ballot
2579+ // argument based on the RHS.
2580+ if (BallotCall) {
2581+ Value *BallotArg = BallotCall->getArgOperand (0 );
2582+ if (BallotArg->getType ()->isIntegerTy (1 )) {
2583+ // Case 1: ballot(cond) == -1: cond true in all lanes -> cond = true.
2584+ // Case 2: ballot(cond) == exec_mask: cond true in all active lanes ->
2585+ // cond = true.
2586+ if (match (CompareValue, m_AllOnes ()) || IsExecMask (CompareValue)) {
25512587 Worklist.push_back (std::make_pair (
25522588 BallotArg, ConstantInt::getTrue (BallotArg->getType ())));
25532589 continue ;
25542590 }
2591+ // Case 3: ballot(cond) == 0: cond false in all lanes -> cond = false.
2592+ if (match (CompareValue, m_Zero ())) {
2593+ Worklist.push_back (std::make_pair (
2594+ BallotArg, ConstantInt::getFalse (BallotArg->getType ())));
2595+ continue ;
2596+ }
25552597 }
25562598 }
25572599
0 commit comments