Skip to content

Commit 0117763

Browse files
committed
Change the conditon to check for convergence token kind
1 parent 72d5693 commit 0117763

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

llvm/lib/Transforms/Utils/BasicBlockUtils.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ emptyAndDetachBlock(BasicBlock *BB,
9292
"applying corresponding DTU updates.");
9393
}
9494

95+
static bool HasLoopOrEntryConvergenceToken(const BasicBlock *BB) {
96+
for (const Instruction &I : *BB) {
97+
const ConvergenceControlInst *CCI = dyn_cast<ConvergenceControlInst>(&I);
98+
if (CCI && (CCI->isLoop() || CCI->isEntry()))
99+
return true;
100+
}
101+
return false;
102+
}
103+
95104
void llvm::detachDeadBlocks(ArrayRef<BasicBlock *> BBs,
96105
SmallVectorImpl<DominatorTree::UpdateType> *Updates,
97106
bool KeepOneInputPHIs) {
@@ -230,15 +239,6 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
230239
// Don't break self-loops.
231240
if (PredBB == BB) return false;
232241

233-
// Don't break if both the basic block and the predecessor contain convergent
234-
// intrinsics.
235-
for (Instruction &I : *BB)
236-
if (isa<ConvergenceControlInst>(I)) {
237-
for (Instruction &I : *PredBB)
238-
if (isa<ConvergenceControlInst>(I))
239-
return false;
240-
}
241-
242242
// Don't break unwinding instructions or terminators with other side-effects.
243243
Instruction *PTI = PredBB->getTerminator();
244244
if (PTI->isSpecialTerminator() || PTI->mayHaveSideEffects())
@@ -268,6 +268,13 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
268268
if (llvm::is_contained(PN.incoming_values(), &PN))
269269
return false;
270270

271+
// Don't break if both the basic block and the predecessor contain loop or
272+
// entry convergent intrinsics, since there may only be one convergence token
273+
// per block.
274+
if (HasLoopOrEntryConvergenceToken(BB) &&
275+
HasLoopOrEntryConvergenceToken(PredBB))
276+
return false;
277+
271278
LLVM_DEBUG(dbgs() << "Merging: " << BB->getName() << " into "
272279
<< PredBB->getName() << "\n");
273280

0 commit comments

Comments
 (0)