Skip to content

Commit e60ca86

Browse files
authored
[AMDGPU] Refine GCNHazardRecognizer hasHazard() (#138841)
Remove recursion to avoid stack overflow on large CFGs. Avoid the worklist for hazard searches within a single MachineBasicBlock. Ensure predecessors are visited for all state combinations.
1 parent 1359f3a commit e60ca86

File tree

1 file changed

+106
-31
lines changed

1 file changed

+106
-31
lines changed

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 106 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -443,40 +443,101 @@ using GetNumWaitStatesFn = function_ref<unsigned int(const MachineInstr &)>;
443443
// Search for a hazard in a block and its predecessors.
444444
template <typename StateT>
445445
static bool
446-
hasHazard(StateT State,
446+
hasHazard(StateT InitialState,
447447
function_ref<HazardFnResult(StateT &, const MachineInstr &)> IsHazard,
448448
function_ref<void(StateT &, const MachineInstr &)> UpdateState,
449-
const MachineBasicBlock *MBB,
450-
MachineBasicBlock::const_reverse_instr_iterator I,
451-
DenseSet<const MachineBasicBlock *> &Visited) {
452-
for (auto E = MBB->instr_rend(); I != E; ++I) {
453-
// No need to look at parent BUNDLE instructions.
454-
if (I->isBundle())
455-
continue;
456-
457-
switch (IsHazard(State, *I)) {
458-
case HazardFound:
459-
return true;
460-
case HazardExpired:
461-
return false;
462-
default:
463-
// Continue search
464-
break;
449+
const MachineBasicBlock *InitialMBB,
450+
MachineBasicBlock::const_reverse_instr_iterator InitialI) {
451+
struct StateMapKey {
452+
SmallVectorImpl<StateT> *States;
453+
unsigned Idx;
454+
static bool isEqual(const StateMapKey &LHS, const StateMapKey &RHS) {
455+
return LHS.States == RHS.States && LHS.Idx == RHS.Idx;
456+
}
457+
};
458+
struct StateMapKeyTraits : DenseMapInfo<StateMapKey> {
459+
static inline StateMapKey getEmptyKey() {
460+
return {static_cast<SmallVectorImpl<StateT> *>(
461+
DenseMapInfo<void *>::getEmptyKey()),
462+
DenseMapInfo<unsigned>::getEmptyKey()};
463+
}
464+
static inline StateMapKey getTombstoneKey() {
465+
return {static_cast<SmallVectorImpl<StateT> *>(
466+
DenseMapInfo<void *>::getTombstoneKey()),
467+
DenseMapInfo<unsigned>::getTombstoneKey()};
468+
}
469+
static unsigned getHashValue(const StateMapKey &Key) {
470+
return StateT::getHashValue((*Key.States)[Key.Idx]);
465471
}
472+
static unsigned getHashValue(const StateT &State) {
473+
return StateT::getHashValue(State);
474+
}
475+
static bool isEqual(const StateMapKey &LHS, const StateMapKey &RHS) {
476+
const auto EKey = getEmptyKey();
477+
const auto TKey = getTombstoneKey();
478+
if (StateMapKey::isEqual(LHS, EKey) || StateMapKey::isEqual(RHS, EKey) ||
479+
StateMapKey::isEqual(LHS, TKey) || StateMapKey::isEqual(RHS, TKey))
480+
return StateMapKey::isEqual(LHS, RHS);
481+
return StateT::isEqual((*LHS.States)[LHS.Idx], (*RHS.States)[RHS.Idx]);
482+
}
483+
static bool isEqual(const StateT &LHS, const StateMapKey &RHS) {
484+
if (StateMapKey::isEqual(RHS, getEmptyKey()) ||
485+
StateMapKey::isEqual(RHS, getTombstoneKey()))
486+
return false;
487+
return StateT::isEqual(LHS, (*RHS.States)[RHS.Idx]);
488+
}
489+
};
466490

467-
if (I->isInlineAsm() || I->isMetaInstruction())
468-
continue;
491+
SmallDenseMap<StateMapKey, unsigned, 8, StateMapKeyTraits> StateMap;
492+
SmallVector<StateT, 8> States;
469493

470-
UpdateState(State, *I);
471-
}
494+
MachineBasicBlock::const_reverse_instr_iterator I = InitialI;
495+
const MachineBasicBlock *MBB = InitialMBB;
496+
StateT State = InitialState;
472497

473-
for (MachineBasicBlock *Pred : MBB->predecessors()) {
474-
if (!Visited.insert(Pred).second)
475-
continue;
498+
SmallSetVector<std::pair<const MachineBasicBlock *, unsigned>, 16> Worklist;
499+
unsigned WorkIdx = 0;
500+
for (;;) {
501+
bool Expired = false;
502+
for (auto E = MBB->instr_rend(); I != E; ++I) {
503+
// No need to look at parent BUNDLE instructions.
504+
if (I->isBundle())
505+
continue;
476506

477-
if (hasHazard(State, IsHazard, UpdateState, Pred, Pred->instr_rbegin(),
478-
Visited))
479-
return true;
507+
auto Result = IsHazard(State, *I);
508+
if (Result == HazardFound)
509+
return true;
510+
if (Result == HazardExpired) {
511+
Expired = true;
512+
break;
513+
}
514+
515+
if (I->isInlineAsm() || I->isMetaInstruction())
516+
continue;
517+
518+
UpdateState(State, *I);
519+
}
520+
521+
if (!Expired) {
522+
unsigned StateIdx = States.size();
523+
StateMapKey Key = {&States, StateIdx};
524+
auto Insertion = StateMap.insert_as(std::pair(Key, StateIdx), State);
525+
if (Insertion.second) {
526+
States.emplace_back(State);
527+
} else {
528+
StateIdx = Insertion.first->second;
529+
}
530+
for (MachineBasicBlock *Pred : MBB->predecessors())
531+
Worklist.insert(std::pair(Pred, StateIdx));
532+
}
533+
534+
if (WorkIdx == Worklist.size())
535+
break;
536+
537+
unsigned StateIdx;
538+
std::tie(MBB, StateIdx) = Worklist[WorkIdx++];
539+
State = States[StateIdx];
540+
I = MBB->instr_rbegin();
480541
}
481542

482543
return false;
@@ -1641,6 +1702,15 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
16411702
SmallDenseMap<Register, int, 4> DefPos;
16421703
int ExecPos = std::numeric_limits<int>::max();
16431704
int VALUs = 0;
1705+
1706+
static unsigned getHashValue(const StateType &State) {
1707+
return hash_combine(State.ExecPos, State.VALUs,
1708+
hash_combine_range(State.DefPos));
1709+
}
1710+
static bool isEqual(const StateType &LHS, const StateType &RHS) {
1711+
return LHS.DefPos == RHS.DefPos && LHS.ExecPos == RHS.ExecPos &&
1712+
LHS.VALUs == RHS.VALUs;
1713+
}
16441714
};
16451715

16461716
StateType State;
@@ -1735,9 +1805,8 @@ bool GCNHazardRecognizer::fixVALUPartialForwardingHazard(MachineInstr *MI) {
17351805
State.VALUs += 1;
17361806
};
17371807

1738-
DenseSet<const MachineBasicBlock *> Visited;
17391808
if (!hasHazard<StateType>(State, IsHazardFn, UpdateStateFn, MI->getParent(),
1740-
std::next(MI->getReverseIterator()), Visited))
1809+
std::next(MI->getReverseIterator())))
17411810
return false;
17421811

17431812
BuildMI(*MI->getParent(), MI, MI->getDebugLoc(),
@@ -1778,6 +1847,13 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
17781847
struct StateType {
17791848
int VALUs = 0;
17801849
int TRANS = 0;
1850+
1851+
static unsigned getHashValue(const StateType &State) {
1852+
return hash_combine(State.VALUs, State.TRANS);
1853+
}
1854+
static bool isEqual(const StateType &LHS, const StateType &RHS) {
1855+
return LHS.VALUs == RHS.VALUs && LHS.TRANS == RHS.TRANS;
1856+
}
17811857
};
17821858

17831859
StateType State;
@@ -1813,9 +1889,8 @@ bool GCNHazardRecognizer::fixVALUTransUseHazard(MachineInstr *MI) {
18131889
State.TRANS += 1;
18141890
};
18151891

1816-
DenseSet<const MachineBasicBlock *> Visited;
18171892
if (!hasHazard<StateType>(State, IsHazardFn, UpdateStateFn, MI->getParent(),
1818-
std::next(MI->getReverseIterator()), Visited))
1893+
std::next(MI->getReverseIterator())))
18191894
return false;
18201895

18211896
// Hazard is observed - insert a wait on va_dst counter to ensure hazard is

0 commit comments

Comments
 (0)