@@ -1345,9 +1345,9 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13451345
13461346 // Fold ballot intrinsic based on llvm.assume hint about the result.
13471347 //
1348- // assume(ballot(x) == ballot(i1 true)) -> x = true
1349- // assume(ballot(x) == -1) -> x = true
1350- // assume(ballot(x) == 0) -> x = false
1348+ // assume(ballot(x) == ballot(true)) -> x = true
1349+ // assume(ballot(x) == -1) -> x = true
1350+ // assume(ballot(x) == 0) -> x = false
13511351 if (Arg->getType ()->isIntegerTy (1 )) {
13521352 for (auto &AssumeVH : IC.getAssumptionCache ().assumptionsFor (&II)) {
13531353 if (!AssumeVH)
@@ -1368,29 +1368,42 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13681368 continue ;
13691369
13701370 // Determine the constant value of the ballot's condition argument.
1371- std::optional<bool > PropagatedBool;
1372- if (match (CompareValue, m_AllOnes ()) ||
1373- match (CompareValue,
1374- m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One ()))) {
1375- // ballot(x) == -1 or ballot(x) == ballot(true) means x is true.
1376- PropagatedBool = true ;
1377- } else if (match (CompareValue, m_Zero ())) {
1378- // ballot(x) == 0 means x is false.
1379- PropagatedBool = false ;
1371+ std::optional<bool > InferredCondValue;
1372+ if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
1373+ // ballot(x) == -1 means all lanes have x = true.
1374+ if (CI->isMinusOne ())
1375+ InferredCondValue = true ;
1376+ // ballot(x) == 0 means all lanes have x = false.
1377+ else if (CI->isZero ())
1378+ InferredCondValue = false ;
1379+ } else if (match (CompareValue,
1380+ m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One ()))) {
1381+ // ballot(x) == ballot(true) means x = true.
1382+ InferredCondValue = true ;
1383+ } else if (match (CompareValue,
1384+ m_Intrinsic<Intrinsic::amdgcn_ballot>(m_Zero ()))) {
1385+ // ballot(x) == ballot(false) means x = false.
1386+ InferredCondValue = false ;
13801387 }
13811388
1382- if (!PropagatedBool )
1389+ if (!InferredCondValue )
13831390 continue ;
13841391
1385- Constant *PropagatedValue =
1386- ConstantInt::getBool (Arg->getContext (), *PropagatedBool );
1392+ Constant *ReplacementValue =
1393+ ConstantInt::getBool (Arg->getContext (), *InferredCondValue );
13871394
1388- // Replace dominated uses of the ballot's condition argument with the
1389- // propagated value.
1390- Arg->replaceUsesWithIf (PropagatedValue , [&](Use &U) {
1395+ // Replace dominated uses of the condition argument.
1396+ bool Changed = false ;
1397+ Arg->replaceUsesWithIf (ReplacementValue , [&](Use &U) {
13911398 Instruction *UserInst = dyn_cast<Instruction>(U.getUser ());
1392- return UserInst && IC.getDominatorTree ().dominates (Assume, U);
1399+ bool Dominates =
1400+ UserInst && IC.getDominatorTree ().dominates (Assume, U);
1401+ Changed |= Dominates;
1402+ return Dominates;
13931403 });
1404+
1405+ if (Changed)
1406+ return nullptr ;
13941407 }
13951408 }
13961409
0 commit comments