diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp index cc39fd1740683..fedd57b7441d4 100644 --- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp +++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp @@ -495,21 +495,18 @@ hasHazard(StateT State, return false; } -// Returns a minimum wait states since \p I walking all predecessors. -// Only scans until \p IsExpired does not return true. -// Can only be run in a hazard recognizer mode. -static int getWaitStatesSince( +// Update \p WaitStates while iterating from \p I to hazard in \p MBB. +static HazardFnResult countWaitStatesSince( GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB, - MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates, - IsExpiredFn IsExpired, DenseSet &Visited, - GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) { + MachineBasicBlock::const_reverse_instr_iterator I, int &WaitStates, + IsExpiredFn IsExpired, GetNumWaitStatesFn GetNumWaitStates) { for (auto E = MBB->instr_rend(); I != E; ++I) { // Don't add WaitStates for parent BUNDLE instructions. if (I->isBundle()) continue; if (IsHazard(*I)) - return WaitStates; + return HazardFound; if (I->isInlineAsm()) continue; @@ -517,29 +514,91 @@ static int getWaitStatesSince( WaitStates += GetNumWaitStates(*I); if (IsExpired(*I, WaitStates)) - return std::numeric_limits::max(); + return HazardExpired; } + return NoHazardFound; +} + +// Implements predecessor search for getWaitStatesSince. +static int getWaitStatesSinceImpl( + GCNHazardRecognizer::IsHazardFn IsHazard, + const MachineBasicBlock *InitialMBB, int InitialWaitStates, + IsExpiredFn IsExpired, + GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) { + DenseMap Visited; + + // Build worklist of predecessors. + // Note: use queue so search is breadth first, which reduces search space + // when a hazard is found. + SmallVector Worklist; + for (MachineBasicBlock *Pred : InitialMBB->predecessors()) { + Visited[Pred] = InitialWaitStates; + Worklist.push_back(Pred); + } + + // Find minimum wait states to hazard or determine that all paths expire. int MinWaitStates = std::numeric_limits::max(); - for (MachineBasicBlock *Pred : MBB->predecessors()) { - if (!Visited.insert(Pred).second) - continue; + unsigned Idx = 0; + while (Idx < Worklist.size()) { + const MachineBasicBlock *MBB = Worklist[Idx++]; + int WaitStates = Visited[MBB]; + + // Make sure that worklist capacity is reused in large CFGs. + if (Idx >= 1024) { + Worklist.erase(Worklist.begin(), Worklist.begin() + (Idx - 1)); + Idx = 0; + } - int W = getWaitStatesSince(IsHazard, Pred, Pred->instr_rbegin(), WaitStates, - IsExpired, Visited, GetNumWaitStates); + // No reason to search blocks when wait states exceed established minimum. + if (WaitStates >= MinWaitStates) + continue; - MinWaitStates = std::min(MinWaitStates, W); + // Search for hazard + auto Search = countWaitStatesSince(IsHazard, MBB, MBB->instr_rbegin(), + WaitStates, IsExpired, GetNumWaitStates); + if (Search == HazardFound) { + // Update minimum. + MinWaitStates = std::min(MinWaitStates, WaitStates); + } else if (Search == NoHazardFound && WaitStates < MinWaitStates) { + // Search predecessors. + for (MachineBasicBlock *Pred : MBB->predecessors()) { + if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) { + // Store lowest wait states required to visit this block. + Visited[Pred] = WaitStates; + Worklist.push_back(Pred); + } + } + } } return MinWaitStates; } +// Returns minimum wait states since \p I walking all predecessors. +// Only scans until \p IsExpired does not return true. +// Can only be run in a hazard recognizer mode. +static int getWaitStatesSince( + GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB, + MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates, + IsExpiredFn IsExpired, + GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) { + // Scan this block from I. + auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates, + IsExpired, GetNumWaitStates); + if (InitSearch == HazardFound) + return WaitStates; + if (InitSearch == HazardExpired) + return std::numeric_limits::max(); + + return getWaitStatesSinceImpl(IsHazard, MBB, WaitStates, IsExpired, + GetNumWaitStates); +} + static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard, const MachineInstr *MI, IsExpiredFn IsExpired) { - DenseSet Visited; return getWaitStatesSince(IsHazard, MI->getParent(), - std::next(MI->getReverseIterator()), - 0, IsExpired, Visited); + std::next(MI->getReverseIterator()), 0, IsExpired); } int GCNHazardRecognizer::getWaitStatesSince(IsHazardFn IsHazard, int Limit) { @@ -1524,10 +1583,9 @@ bool GCNHazardRecognizer::fixLdsDirectVALUHazard(MachineInstr *MI) { return SIInstrInfo::isVALU(MI) ? 1 : 0; }; - DenseSet Visited; auto Count = ::getWaitStatesSince(IsHazardFn, MI->getParent(), std::next(MI->getReverseIterator()), 0, - IsExpiredFn, Visited, GetWaitStatesFn); + IsExpiredFn, GetWaitStatesFn); // Transcendentals can execute in parallel to other VALUs. // This makes va_vdst count unusable with a mixture of VALU and TRANS. @@ -3234,10 +3292,9 @@ bool GCNHazardRecognizer::fixVALUReadSGPRHazard(MachineInstr *MI) { }; // Check for the hazard. - DenseSet Visited; int WaitStates = ::getWaitStatesSince(IsHazardFn, MI->getParent(), std::next(MI->getReverseIterator()), 0, - IsExpiredFn, Visited, WaitStatesFn); + IsExpiredFn, WaitStatesFn); if (WaitStates >= SALUExpiryCount) return false;