Skip to content

Commit e434695

Browse files
Simplified the logic to check for CompareValue and return nullptr when Ballot arg is modified
1 parent faabdfc commit e434695

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,9 +1345,9 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13451345

13461346
// Fold ballot intrinsic based on llvm.assume hint about the result.
13471347
//
1348-
// assume(ballot(x) == ballot(i1 true)) -> x = true
1349-
// assume(ballot(x) == -1) -> x = true
1350-
// assume(ballot(x) == 0) -> x = false
1348+
// assume(ballot(x) == ballot(true)) -> x = true
1349+
// assume(ballot(x) == -1) -> x = true
1350+
// assume(ballot(x) == 0) -> x = false
13511351
if (Arg->getType()->isIntegerTy(1)) {
13521352
for (auto &AssumeVH : IC.getAssumptionCache().assumptionsFor(&II)) {
13531353
if (!AssumeVH)
@@ -1368,29 +1368,42 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13681368
continue;
13691369

13701370
// Determine the constant value of the ballot's condition argument.
1371-
std::optional<bool> PropagatedBool;
1372-
if (match(CompareValue, m_AllOnes()) ||
1373-
match(CompareValue,
1374-
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
1375-
// ballot(x) == -1 or ballot(x) == ballot(true) means x is true.
1376-
PropagatedBool = true;
1377-
} else if (match(CompareValue, m_Zero())) {
1378-
// ballot(x) == 0 means x is false.
1379-
PropagatedBool = false;
1371+
std::optional<bool> InferredCondValue;
1372+
if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
1373+
// ballot(x) == -1 means all lanes have x = true.
1374+
if (CI->isMinusOne())
1375+
InferredCondValue = true;
1376+
// ballot(x) == 0 means all lanes have x = false.
1377+
else if (CI->isZero())
1378+
InferredCondValue = false;
1379+
} else if (match(CompareValue,
1380+
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
1381+
// ballot(x) == ballot(true) means x = true.
1382+
InferredCondValue = true;
1383+
} else if (match(CompareValue,
1384+
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_Zero()))) {
1385+
// ballot(x) == ballot(false) means x = false.
1386+
InferredCondValue = false;
13801387
}
13811388

1382-
if (!PropagatedBool)
1389+
if (!InferredCondValue)
13831390
continue;
13841391

1385-
Constant *PropagatedValue =
1386-
ConstantInt::getBool(Arg->getContext(), *PropagatedBool);
1392+
Constant *ReplacementValue =
1393+
ConstantInt::getBool(Arg->getContext(), *InferredCondValue);
13871394

1388-
// Replace dominated uses of the ballot's condition argument with the
1389-
// propagated value.
1390-
Arg->replaceUsesWithIf(PropagatedValue, [&](Use &U) {
1395+
// Replace dominated uses of the condition argument.
1396+
bool Changed = false;
1397+
Arg->replaceUsesWithIf(ReplacementValue, [&](Use &U) {
13911398
Instruction *UserInst = dyn_cast<Instruction>(U.getUser());
1392-
return UserInst && IC.getDominatorTree().dominates(Assume, U);
1399+
bool Dominates =
1400+
UserInst && IC.getDominatorTree().dominates(Assume, U);
1401+
Changed |= Dominates;
1402+
return Dominates;
13931403
});
1404+
1405+
if (Changed)
1406+
return nullptr;
13941407
}
13951408
}
13961409

0 commit comments

Comments
 (0)