@@ -1348,63 +1348,61 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
13481348 // assume(ballot(x) == ballot(true)) -> x = true
13491349 // assume(ballot(x) == -1) -> x = true
13501350 // assume(ballot(x) == 0) -> x = false
1351- if (Arg->getType ()->isIntegerTy (1 )) {
1352- for (auto &AssumeVH : IC.getAssumptionCache ().assumptionsFor (&II)) {
1353- if (!AssumeVH)
1354- continue ;
1355-
1356- auto *Assume = cast<AssumeInst>(AssumeVH);
1357- Value *Cond = Assume->getArgOperand (0 );
1358-
1359- // Check if assume condition is an equality comparison.
1360- auto *ICI = dyn_cast<ICmpInst>(Cond);
1361- if (!ICI || ICI->getPredicate () != ICmpInst::ICMP_EQ)
1362- continue ;
1363-
1364- // Extract the ballot and the value being compared against it.
1365- Value *LHS = ICI->getOperand (0 ), *RHS = ICI->getOperand (1 );
1366- Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr ;
1367- if (!CompareValue)
1368- continue ;
1369-
1370- // Determine the constant value of the ballot's condition argument.
1371- std::optional<bool > InferredCondValue;
1372- if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
1373- // ballot(x) == -1 means all lanes have x = true.
1374- if (CI->isMinusOne ())
1375- InferredCondValue = true ;
1376- // ballot(x) == 0 means all lanes have x = false.
1377- else if (CI->isZero ())
1378- InferredCondValue = false ;
1379- } else if (match (CompareValue,
1380- m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One ()))) {
1381- // ballot(x) == ballot(true) means x = true.
1351+ //
1352+ // Skip if ballot width < wave size (e.g., ballot.i32 on wave64).
1353+ if (ST->isWave64 () && II.getType ()->getIntegerBitWidth () == 32 )
1354+ break ; // ballot.i32 on wave64 captures only lanes [0:31]
1355+
1356+ for (auto &AssumeVH : IC.getAssumptionCache ().assumptionsFor (&II)) {
1357+ if (!AssumeVH)
1358+ continue ;
1359+
1360+ auto *Assume = cast<AssumeInst>(AssumeVH);
1361+ Value *Cond = Assume->getArgOperand (0 );
1362+
1363+ // Check if assume condition is an equality comparison.
1364+ auto *ICI = dyn_cast<ICmpInst>(Cond);
1365+ if (!ICI || ICI->getPredicate () != ICmpInst::ICMP_EQ)
1366+ continue ;
1367+
1368+ // Extract the ballot and the value being compared against it.
1369+ Value *LHS = ICI->getOperand (0 ), *RHS = ICI->getOperand (1 );
1370+ Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr ;
1371+ if (!CompareValue)
1372+ continue ;
1373+
1374+ // Determine the constant value of the ballot's condition argument.
1375+ std::optional<bool > InferredCondValue;
1376+ if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
1377+ // ballot(x) == -1 means all lanes have x = true.
1378+ if (CI->isMinusOne ())
13821379 InferredCondValue = true ;
1383- } else if (match (CompareValue,
1384- m_Intrinsic<Intrinsic::amdgcn_ballot>(m_Zero ()))) {
1385- // ballot(x) == ballot(false) means x = false.
1380+ // ballot(x) == 0 means all lanes have x = false.
1381+ else if (CI->isZero ())
13861382 InferredCondValue = false ;
1387- }
1383+ } else if (match (CompareValue,
1384+ m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One ()))) {
1385+ // ballot(x) == ballot(true) means x = true (EXEC mask comparison).
1386+ InferredCondValue = true ;
1387+ }
13881388
1389- if (!InferredCondValue)
1390- continue ;
1389+ if (!InferredCondValue)
1390+ continue ;
13911391
1392- Constant *ReplacementValue =
1393- ConstantInt::getBool (Arg->getContext (), *InferredCondValue);
1392+ Constant *ReplacementValue =
1393+ ConstantInt::getBool (Arg->getContext (), *InferredCondValue);
13941394
1395- // Replace dominated uses of the condition argument.
1396- bool Changed = false ;
1397- Arg->replaceUsesWithIf (ReplacementValue, [&](Use &U) {
1398- Instruction *UserInst = dyn_cast<Instruction>(U.getUser ());
1399- bool Dominates =
1400- UserInst && IC.getDominatorTree ().dominates (Assume, U);
1401- Changed |= Dominates;
1402- return Dominates;
1403- });
1395+ // Replace uses of the condition argument dominated by the assume.
1396+ bool Changed = false ;
1397+ Arg->replaceUsesWithIf (ReplacementValue, [&](Use &U) {
1398+ Instruction *UserInst = dyn_cast<Instruction>(U.getUser ());
1399+ bool Dominates = UserInst && IC.getDominatorTree ().dominates (Assume, U);
1400+ Changed |= Dominates;
1401+ return Dominates;
1402+ });
14041403
1405- if (Changed)
1406- return nullptr ;
1407- }
1404+ if (Changed)
1405+ return nullptr ;
14081406 }
14091407
14101408 break ;
0 commit comments