[AMDGPU] GCNHazardRecognizer: refactor getWaitStatesSince (NFCI) #108347
Refactor getWaitStatesSince:
* Remove recursion to avoid excess stack usage with newer hazards.
* Ensure the algorithm always returns minima to hazards by allowing revisiting of blocks if a shorter path is encountered.
* Reduce the search space by actively pruning deeper search after a minimum is established.

Note: in edge cases this might be slightly slower as it now searches to find the true minimum number of wait states.
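The three points in the description can be illustrated on a toy CFG. The sketch below is not the patch's code: `ToyBlock`, `makeDiamond`, and `minWaitStatesToHazard` are hypothetical names, each block is reduced to a single wait-state cost, and expiry is modelled as a plain `Limit` rather than an `IsExpired` callback.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <queue>
#include <unordered_map>
#include <vector>

// Toy stand-in for a basic block (hypothetical, not an LLVM type).
struct ToyBlock {
  int Cost;               // wait states added by walking the whole block
  int HazardOffset;       // wait states before the hazard, or -1 if none
  std::vector<int> Preds; // indices of predecessor blocks
};

// Diamond: start(4) -> {A(2), B(3)} -> C(1) -> H(0), walking predecessors.
std::vector<ToyBlock> makeDiamond() {
  std::vector<ToyBlock> CFG;
  CFG.push_back(ToyBlock{0, 0, std::vector<int>{}});      // 0: H, hazard
  CFG.push_back(ToyBlock{1, -1, std::vector<int>{0}});    // 1: C
  CFG.push_back(ToyBlock{5, -1, std::vector<int>{1}});    // 2: A, long path
  CFG.push_back(ToyBlock{1, -1, std::vector<int>{1}});    // 3: B, short path
  CFG.push_back(ToyBlock{0, -1, std::vector<int>{2, 3}}); // 4: start
  return CFG;
}

// Iterative breadth-first search over predecessors of Start. Returns the
// minimum wait states to a hazard, or INT_MAX if every path expires.
int minWaitStatesToHazard(const std::vector<ToyBlock> &CFG, int Start,
                          int Limit) {
  const int Inf = std::numeric_limits<int>::max();
  std::unordered_map<int, int> Visited; // block -> lowest entry wait states
  std::queue<int> Worklist;
  for (int Pred : CFG[Start].Preds) {
    Visited[Pred] = 0;
    Worklist.push(Pred);
  }

  int Min = Inf;
  while (!Worklist.empty()) {
    int B = Worklist.front();
    Worklist.pop();
    int WaitStates = Visited[B];
    if (WaitStates >= Min)
      continue; // cannot improve on the established minimum

    const ToyBlock &Blk = CFG[B];
    if (Blk.HazardOffset >= 0) {
      int D = WaitStates + Blk.HazardOffset;
      if (D < Limit)
        Min = std::min(Min, D);
      continue;
    }

    WaitStates += Blk.Cost;
    if (WaitStates >= Limit || WaitStates >= Min)
      continue; // path expired or pruned

    for (int Pred : Blk.Preds) {
      auto It = Visited.find(Pred);
      // Revisit a block only when this path reaches it more cheaply.
      if (It == Visited.end() || WaitStates < It->second) {
        Visited[Pred] = WaitStates;
        Worklist.push(Pred);
      }
    }
  }
  return Min;
}
```

On the diamond, block C is first reached through the long path (5 wait states) and then re-queued when the short path reaches it with 1, so the result is the shorter distance. Because the recorded value only ever decreases strictly, the loop still terminates on cyclic graphs.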
@llvm/pr-subscribers-backend-amdgpu
Author: Carl Ritson (perlfu)
Changes: Refactor getWaitStatesSince.
Note: in edge cases this might be slightly slower as it now searches to find the true minimum number of wait states.
Full diff: https://github.com/llvm/llvm-project/pull/108347.diff
1 file affected:
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index cc39fd1740683f..5150fabd173fa6 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -19,6 +19,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/ScheduleDAG.h"
#include "llvm/TargetParser/TargetParser.h"
+#include <queue>
using namespace llvm;
@@ -495,21 +496,18 @@ hasHazard(StateT State,
return false;
}
-// Returns a minimum wait states since \p I walking all predecessors.
-// Only scans until \p IsExpired does not return true.
-// Can only be run in a hazard recognizer mode.
-static int getWaitStatesSince(
+// Update \p WaitStates while iterating from \p I to hazard in \p MBB.
+static HazardFnResult countWaitStatesSince(
GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
- MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
- IsExpiredFn IsExpired, DenseSet<const MachineBasicBlock *> &Visited,
- GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+ MachineBasicBlock::const_reverse_instr_iterator I, int &WaitStates,
+ IsExpiredFn IsExpired, GetNumWaitStatesFn GetNumWaitStates) {
for (auto E = MBB->instr_rend(); I != E; ++I) {
// Don't add WaitStates for parent BUNDLE instructions.
if (I->isBundle())
continue;
if (IsHazard(*I))
- return WaitStates;
+ return HazardFound;
if (I->isInlineAsm())
continue;
@@ -517,29 +515,85 @@ static int getWaitStatesSince(
WaitStates += GetNumWaitStates(*I);
if (IsExpired(*I, WaitStates))
- return std::numeric_limits<int>::max();
+ return HazardExpired;
+ }
+
+ return NoHazardFound;
+}
+
+// Implements predecessor search for getWaitStatesSince.
+static int getWaitStatesSinceImpl(
+ GCNHazardRecognizer::IsHazardFn IsHazard,
+ const MachineBasicBlock *InitialMBB, int InitialWaitStates,
+ IsExpiredFn IsExpired,
+ GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+ DenseMap<const MachineBasicBlock *, int> Visited;
+
+ // Build worklist of predecessors.
+ // Note: use queue so search is breadth first, which reduces search space
+ // when a hazard is found.
+ std::queue<const MachineBasicBlock *> Worklist;
+ for (MachineBasicBlock *Pred : InitialMBB->predecessors()) {
+ Visited[Pred] = InitialWaitStates;
+ Worklist.push(Pred);
}
+ // Find minimum wait states to hazard or determine that all paths expire.
int MinWaitStates = std::numeric_limits<int>::max();
- for (MachineBasicBlock *Pred : MBB->predecessors()) {
- if (!Visited.insert(Pred).second)
- continue;
+ while (!Worklist.empty()) {
+ const MachineBasicBlock *MBB = Worklist.front();
+ int WaitStates = Visited[MBB];
+ Worklist.pop();
- int W = getWaitStatesSince(IsHazard, Pred, Pred->instr_rbegin(), WaitStates,
- IsExpired, Visited, GetNumWaitStates);
+ // No reason to search blocks when wait states exceed established minimum.
+ if (WaitStates >= MinWaitStates)
+ continue;
- MinWaitStates = std::min(MinWaitStates, W);
+ // Search for hazard
+ auto Search = countWaitStatesSince(IsHazard, MBB, MBB->instr_rbegin(),
+ WaitStates, IsExpired, GetNumWaitStates);
+ if (Search == HazardFound) {
+ // Update minimum.
+ MinWaitStates = std::min(MinWaitStates, WaitStates);
+ } else if (Search == NoHazardFound && WaitStates < MinWaitStates) {
+ // Search predecessors.
+ for (MachineBasicBlock *Pred : MBB->predecessors()) {
+ if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) {
+ // Store lowest wait states required to visit this block.
+ Visited[Pred] = WaitStates;
+ Worklist.push(Pred);
+ }
+ }
+ }
}
return MinWaitStates;
}
+// Returns minimum wait states since \p I walking all predecessors.
+// Only scans until \p IsExpired does not return true.
+// Can only be run in a hazard recognizer mode.
+static int getWaitStatesSince(
+ GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
+ MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
+ IsExpiredFn IsExpired,
+ GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+ // Scan this block from I.
+ auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates,
+ IsExpired, GetNumWaitStates);
+ if (InitSearch == HazardFound)
+ return WaitStates;
+ else if (InitSearch == HazardExpired)
+ return std::numeric_limits<int>::max();
+ else
+ return getWaitStatesSinceImpl(IsHazard, MBB, WaitStates, IsExpired,
+ GetNumWaitStates);
+}
+
static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
const MachineInstr *MI, IsExpiredFn IsExpired) {
- DenseSet<const MachineBasicBlock *> Visited;
return getWaitStatesSince(IsHazard, MI->getParent(),
- std::next(MI->getReverseIterator()),
- 0, IsExpired, Visited);
+ std::next(MI->getReverseIterator()), 0, IsExpired);
}
int GCNHazardRecognizer::getWaitStatesSince(IsHazardFn IsHazard, int Limit) {
@@ -1524,10 +1578,9 @@ bool GCNHazardRecognizer::fixLdsDirectVALUHazard(MachineInstr *MI) {
return SIInstrInfo::isVALU(MI) ? 1 : 0;
};
- DenseSet<const MachineBasicBlock *> Visited;
auto Count = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
std::next(MI->getReverseIterator()), 0,
- IsExpiredFn, Visited, GetWaitStatesFn);
+ IsExpiredFn, GetWaitStatesFn);
// Transcendentals can execute in parallel to other VALUs.
// This makes va_vdst count unusable with a mixture of VALU and TRANS.
@@ -3234,10 +3287,9 @@ bool GCNHazardRecognizer::fixVALUReadSGPRHazard(MachineInstr *MI) {
};
// Check for the hazard.
- DenseSet<const MachineBasicBlock *> Visited;
int WaitStates = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
std::next(MI->getReverseIterator()), 0,
- IsExpiredFn, Visited, WaitStatesFn);
+ IsExpiredFn, WaitStatesFn);
if (WaitStates >= SALUExpiryCount)
return false;
+1 for the idea of this patch. It will take me a while to review the details. It would be great to see a test case that demonstrates the shortest-path behaviour. I think this is technically a bug fix, right? Incidentally there is another form of recursion in |
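One way to see why a visit-once search can miss the shortest path is a small diamond-shaped CFG. The sketch below is a toy model, not the LLVM code or an LLVM test: `ToyBlock`, `visitOnceSearch`, and `runVisitOnce` are hypothetical names, and expiry is reduced to a fixed limit of 10 wait states. It mirrors only the shape of the pre-patch recursion, where a block already in the visited set is skipped.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <set>
#include <vector>

// Toy stand-in for a basic block (hypothetical, not an LLVM type).
struct ToyBlock {
  int Cost;               // wait states added by walking the whole block
  int HazardOffset;       // wait states before the hazard, or -1 if none
  std::vector<int> Preds; // indices of predecessor blocks
};

// Recursive search that never revisits a block, even via a shorter path.
int visitOnceSearch(const std::vector<ToyBlock> &CFG, int B, int WaitStates,
                    int Limit, std::set<int> &Visited) {
  const int Inf = std::numeric_limits<int>::max();
  int Min = Inf;
  for (int Pred : CFG[B].Preds) {
    if (!Visited.insert(Pred).second)
      continue; // already seen: skipped regardless of the current distance
    const ToyBlock &P = CFG[Pred];
    int W = Inf;
    if (P.HazardOffset >= 0) {
      if (WaitStates + P.HazardOffset < Limit)
        W = WaitStates + P.HazardOffset;
    } else if (WaitStates + P.Cost < Limit) {
      W = visitOnceSearch(CFG, Pred, WaitStates + P.Cost, Limit, Visited);
    }
    Min = std::min(Min, W);
  }
  return Min;
}

// Diamond: start(4) -> {A(2), B(3)} -> C(1) -> H(0), walking predecessors.
// The shortest distance to the hazard is 2 (through B, then C).
int runVisitOnce(bool LongPathFirst) {
  std::vector<ToyBlock> CFG;
  CFG.push_back(ToyBlock{0, 0, std::vector<int>{}});   // 0: H, hazard
  CFG.push_back(ToyBlock{1, -1, std::vector<int>{0}}); // 1: C
  CFG.push_back(ToyBlock{5, -1, std::vector<int>{1}}); // 2: A, long path
  CFG.push_back(ToyBlock{1, -1, std::vector<int>{1}}); // 3: B, short path
  CFG.push_back(ToyBlock{0, -1,
                         LongPathFirst ? std::vector<int>{2, 3}
                                       : std::vector<int>{3, 2}}); // 4: start
  std::set<int> Visited;
  return visitOnceSearch(CFG, 4, 0, 10, Visited);
}
```

In this model the answer depends on predecessor order: when the long path is explored first, C is marked visited at distance 5 and the short path through B is cut off, giving 6 instead of 2.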
jayfoad
left a comment
The logic looks good to me. I just added some nits and musings inline.
}

// Implements predecessor search for getWaitStatesSince.
static int getWaitStatesSinceImpl(
Any particular reason for this to be split out into a separate function, now that it is not recursive?
- Allow renaming of some of the variables to avoid programming errors (on my part).
- Hint to the compiler that this is the "out of line" case.
if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) {
  // Store lowest wait states required to visit this block.
  Visited[Pred] = WaitStates;
  Worklist.push(Pred);
IIUC: in the case where you just want to update Visited[Pred] to a new lower WaitStates value, this might add a second copy of Pred to the worklist, which would be wasteful.
I'm aware of the duplication possibility, but every test I have done suggests the general overhead of adding a set or similar to avoid duplicate visits adds unnecessary cost.
I agree that anti-patterns could exist that represent a worst case, but in practice I don't think this is an issue.
    IsExpiredFn IsExpired,
    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
  // Scan this block from I.
  auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates,
Rather than call countWaitStatesSince here, can't you immediately call into getWaitStatesSinceImpl but initialize the worklist to just MBB? I.e. common up this function with the body of the loop in getWaitStatesSinceImpl? I guess the complexity is that each item on the worklist would have to be an (MBB, iterator) pair, so that you know at what point within each MBB to start searching.
I used this approach in an early draft, but it complicates the logic and adds memory usage for no real benefit.
// Build worklist of predecessors.
// Note: use queue so search is breadth first, which reduces search space
// when a hazard is found.
std::queue<const MachineBasicBlock *> Worklist;
It might be more efficient to use SmallVector Worklist and Worklist.push_back() for inserting and Worklist[Idx++] for popping the front of the queue.
Yes, testing agrees using SmallVector is a bit faster (probably more so in common cases).
However, this method does mean the worklist will essentially grow to encompass the entire search.
I have added some code to reclaim space when the list is (very) large.
Thanks.
"However, this method does mean the worklist will essentially grow to encompass the entire search. I have added some code to reclaim space when the list is (very) large."
I don't see why that is a problem at all. The absolute worst that can happen is we get a vector with one pointer per BB in the function. I don't think it's worth worrying about that amount of memory usage.
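The vector-as-queue idea from this thread can be sketched as follows. This is an illustration only: `IndexQueue`, `demoOrder`, and `demoStorage` are hypothetical names, and `std::vector` stands in for `SmallVector` so the example compiles without LLVM headers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// FIFO built on a growable array: push at the back, pop by advancing an
// index. Popped slots are never reused or freed in this sketch.
class IndexQueue {
  std::vector<int> Items;
  std::size_t Head = 0;

public:
  void push(int V) { Items.push_back(V); }
  bool empty() const { return Head == Items.size(); }
  int pop() { return Items[Head++]; } // caller must check empty() first
  std::size_t storage() const { return Items.size(); }
};

// Interleaves pushes and pops and records the order items come out in.
std::vector<int> demoOrder() {
  IndexQueue Q;
  std::vector<int> Out;
  Q.push(1);
  Q.push(2);
  Out.push_back(Q.pop());
  Q.push(3);
  while (!Q.empty())
    Out.push_back(Q.pop());
  return Out;
}

// Number of slots held after the same sequence has fully drained.
std::size_t demoStorage() {
  IndexQueue Q;
  Q.push(1);
  Q.push(2);
  Q.pop();
  Q.push(3);
  while (!Q.empty())
    Q.pop();
  return Q.storage();
}
```

The queue preserves first-in, first-out order, so the search stays breadth first. The cost is the one raised above: the backing storage holds every item ever pushed until the queue is destroyed, which is the growth the reclaim code was meant to address and which the reviewer considers negligible.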
Thanks. A follow up is to flatten |