Skip to content

Commit 08e5759

Browse files
Skip the optimization when ballot size < wave size
1 parent 5ddb410 commit 08e5759

File tree

3 files changed

+599
-181
lines changed

3 files changed

+599
-181
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,63 +1348,61 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13481348
// assume(ballot(x) == ballot(true)) -> x = true
13491349
// assume(ballot(x) == -1) -> x = true
13501350
// assume(ballot(x) == 0) -> x = false
1351-
if (Arg->getType()->isIntegerTy(1)) {
1352-
for (auto &AssumeVH : IC.getAssumptionCache().assumptionsFor(&II)) {
1353-
if (!AssumeVH)
1354-
continue;
1355-
1356-
auto *Assume = cast<AssumeInst>(AssumeVH);
1357-
Value *Cond = Assume->getArgOperand(0);
1358-
1359-
// Check if assume condition is an equality comparison.
1360-
auto *ICI = dyn_cast<ICmpInst>(Cond);
1361-
if (!ICI || ICI->getPredicate() != ICmpInst::ICMP_EQ)
1362-
continue;
1363-
1364-
// Extract the ballot and the value being compared against it.
1365-
Value *LHS = ICI->getOperand(0), *RHS = ICI->getOperand(1);
1366-
Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr;
1367-
if (!CompareValue)
1368-
continue;
1369-
1370-
// Determine the constant value of the ballot's condition argument.
1371-
std::optional<bool> InferredCondValue;
1372-
if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
1373-
// ballot(x) == -1 means all lanes have x = true.
1374-
if (CI->isMinusOne())
1375-
InferredCondValue = true;
1376-
// ballot(x) == 0 means all lanes have x = false.
1377-
else if (CI->isZero())
1378-
InferredCondValue = false;
1379-
} else if (match(CompareValue,
1380-
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
1381-
// ballot(x) == ballot(true) means x = true.
1351+
//
1352+
// Skip if ballot width < wave size (e.g., ballot.i32 on wave64).
1353+
if (ST->isWave64() && II.getType()->getIntegerBitWidth() == 32)
1354+
break; // ballot.i32 on wave64 captures only lanes [0:31]
1355+
1356+
for (auto &AssumeVH : IC.getAssumptionCache().assumptionsFor(&II)) {
1357+
if (!AssumeVH)
1358+
continue;
1359+
1360+
auto *Assume = cast<AssumeInst>(AssumeVH);
1361+
Value *Cond = Assume->getArgOperand(0);
1362+
1363+
// Check if assume condition is an equality comparison.
1364+
auto *ICI = dyn_cast<ICmpInst>(Cond);
1365+
if (!ICI || ICI->getPredicate() != ICmpInst::ICMP_EQ)
1366+
continue;
1367+
1368+
// Extract the ballot and the value being compared against it.
1369+
Value *LHS = ICI->getOperand(0), *RHS = ICI->getOperand(1);
1370+
Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr;
1371+
if (!CompareValue)
1372+
continue;
1373+
1374+
// Determine the constant value of the ballot's condition argument.
1375+
std::optional<bool> InferredCondValue;
1376+
if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
1377+
// ballot(x) == -1 means all lanes have x = true.
1378+
if (CI->isMinusOne())
13821379
InferredCondValue = true;
1383-
} else if (match(CompareValue,
1384-
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_Zero()))) {
1385-
// ballot(x) == ballot(false) means x = false.
1380+
// ballot(x) == 0 means all lanes have x = false.
1381+
else if (CI->isZero())
13861382
InferredCondValue = false;
1387-
}
1383+
} else if (match(CompareValue,
1384+
m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
1385+
// ballot(x) == ballot(true) means x = true (EXEC mask comparison).
1386+
InferredCondValue = true;
1387+
}
13881388

1389-
if (!InferredCondValue)
1390-
continue;
1389+
if (!InferredCondValue)
1390+
continue;
13911391

1392-
Constant *ReplacementValue =
1393-
ConstantInt::getBool(Arg->getContext(), *InferredCondValue);
1392+
Constant *ReplacementValue =
1393+
ConstantInt::getBool(Arg->getContext(), *InferredCondValue);
13941394

1395-
// Replace dominated uses of the condition argument.
1396-
bool Changed = false;
1397-
Arg->replaceUsesWithIf(ReplacementValue, [&](Use &U) {
1398-
Instruction *UserInst = dyn_cast<Instruction>(U.getUser());
1399-
bool Dominates =
1400-
UserInst && IC.getDominatorTree().dominates(Assume, U);
1401-
Changed |= Dominates;
1402-
return Dominates;
1403-
});
1395+
// Replace uses of the condition argument dominated by the assume.
1396+
bool Changed = false;
1397+
Arg->replaceUsesWithIf(ReplacementValue, [&](Use &U) {
1398+
Instruction *UserInst = dyn_cast<Instruction>(U.getUser());
1399+
bool Dominates = UserInst && IC.getDominatorTree().dominates(Assume, U);
1400+
Changed |= Dominates;
1401+
return Dominates;
1402+
});
14041403

1405-
if (Changed)
1406-
return nullptr;
1407-
}
1404+
if (Changed)
1405+
return nullptr;
14081406
}
14091407

14101408
break;

0 commit comments

Comments
 (0)