Skip to content

Conversation

@perlfu
Copy link
Contributor

@perlfu perlfu commented Sep 12, 2024

Refactor getWaitStatesSince:

  • Remove recursion to avoid excess stack usage with newer hazards
  • Ensure algorithm always returns minima to hazards by allowing revisiting of blocks if a shorter path is encountered.
  • Reduce the search space by actively pruning deeper search after a minimum is established.

Note: in edge cases this might be slightly slower as it now searches to find the true minimum number of wait states.

Refactor getWaitStatesSince:
* Remove recursion to avoid excess stack usage with newer hazards
* Ensure algorithm always returns minima to hazards by allowing
  revisiting of blocks if a shorter path is encountered.
* Reduce the search space by actively pruning deeper search after
  a minimum is established.

Note: in edge cases this might be slightly slower as it now
searches to find the true minimum number of wait states.
@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2024

@llvm/pr-subscribers-backend-amdgpu

Author: Carl Ritson (perlfu)

Changes

Refactor getWaitStatesSince:

  • Remove recursion to avoid excess stack usage with newer hazards
  • Ensure algorithm always returns minima to hazards by allowing revisiting of blocks if a shorter path is encountered.
  • Reduce the search space by actively pruning deeper search after a minimum is established.

Note: in edge cases this might be slightly slower as it now searches to find the true minimum number of wait states.


Full diff: https://github.com/llvm/llvm-project/pull/108347.diff

1 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp (+74-22)
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index cc39fd1740683f..5150fabd173fa6 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -19,6 +19,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/ScheduleDAG.h"
 #include "llvm/TargetParser/TargetParser.h"
+#include <queue>
 
 using namespace llvm;
 
@@ -495,21 +496,18 @@ hasHazard(StateT State,
   return false;
 }
 
-// Returns a minimum wait states since \p I walking all predecessors.
-// Only scans until \p IsExpired does not return true.
-// Can only be run in a hazard recognizer mode.
-static int getWaitStatesSince(
+// Update \p WaitStates while iterating from \p I to hazard in \p MBB.
+static HazardFnResult countWaitStatesSince(
     GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
-    MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
-    IsExpiredFn IsExpired, DenseSet<const MachineBasicBlock *> &Visited,
-    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+    MachineBasicBlock::const_reverse_instr_iterator I, int &WaitStates,
+    IsExpiredFn IsExpired, GetNumWaitStatesFn GetNumWaitStates) {
   for (auto E = MBB->instr_rend(); I != E; ++I) {
     // Don't add WaitStates for parent BUNDLE instructions.
     if (I->isBundle())
       continue;
 
     if (IsHazard(*I))
-      return WaitStates;
+      return HazardFound;
 
     if (I->isInlineAsm())
       continue;
@@ -517,29 +515,85 @@ static int getWaitStatesSince(
     WaitStates += GetNumWaitStates(*I);
 
     if (IsExpired(*I, WaitStates))
-      return std::numeric_limits<int>::max();
+      return HazardExpired;
+  }
+
+  return NoHazardFound;
+}
+
+// Implements predecessor search for getWaitStatesSince.
+static int getWaitStatesSinceImpl(
+    GCNHazardRecognizer::IsHazardFn IsHazard,
+    const MachineBasicBlock *InitialMBB, int InitialWaitStates,
+    IsExpiredFn IsExpired,
+    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+  DenseMap<const MachineBasicBlock *, int> Visited;
+
+  // Build worklist of predecessors.
+  // Note: use queue so search is breadth first, which reduces search space
+  // when a hazard is found.
+  std::queue<const MachineBasicBlock *> Worklist;
+  for (MachineBasicBlock *Pred : InitialMBB->predecessors()) {
+    Visited[Pred] = InitialWaitStates;
+    Worklist.push(Pred);
   }
 
+  // Find minimum wait states to hazard or determine that all paths expire.
   int MinWaitStates = std::numeric_limits<int>::max();
-  for (MachineBasicBlock *Pred : MBB->predecessors()) {
-    if (!Visited.insert(Pred).second)
-      continue;
+  while (!Worklist.empty()) {
+    const MachineBasicBlock *MBB = Worklist.front();
+    int WaitStates = Visited[MBB];
+    Worklist.pop();
 
-    int W = getWaitStatesSince(IsHazard, Pred, Pred->instr_rbegin(), WaitStates,
-                               IsExpired, Visited, GetNumWaitStates);
+    // No reason to search blocks when wait states exceed established minimum.
+    if (WaitStates >= MinWaitStates)
+      continue;
 
-    MinWaitStates = std::min(MinWaitStates, W);
+    // Search for hazard
+    auto Search = countWaitStatesSince(IsHazard, MBB, MBB->instr_rbegin(),
+                                       WaitStates, IsExpired, GetNumWaitStates);
+    if (Search == HazardFound) {
+      // Update minimum.
+      MinWaitStates = std::min(MinWaitStates, WaitStates);
+    } else if (Search == NoHazardFound && WaitStates < MinWaitStates) {
+      // Search predecessors.
+      for (MachineBasicBlock *Pred : MBB->predecessors()) {
+        if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) {
+          // Store lowest wait states required to visit this block.
+          Visited[Pred] = WaitStates;
+          Worklist.push(Pred);
+        }
+      }
+    }
   }
 
   return MinWaitStates;
 }
 
+// Returns minimum wait states since \p I walking all predecessors.
+// Only scans until \p IsExpired does not return true.
+// Can only be run in a hazard recognizer mode.
+static int getWaitStatesSince(
+    GCNHazardRecognizer::IsHazardFn IsHazard, const MachineBasicBlock *MBB,
+    MachineBasicBlock::const_reverse_instr_iterator I, int WaitStates,
+    IsExpiredFn IsExpired,
+    GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
+  // Scan this block from I.
+  auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates,
+                                         IsExpired, GetNumWaitStates);
+  if (InitSearch == HazardFound)
+    return WaitStates;
+  else if (InitSearch == HazardExpired)
+    return std::numeric_limits<int>::max();
+  else
+    return getWaitStatesSinceImpl(IsHazard, MBB, WaitStates, IsExpired,
+                                  GetNumWaitStates);
+}
+
 static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
                               const MachineInstr *MI, IsExpiredFn IsExpired) {
-  DenseSet<const MachineBasicBlock *> Visited;
   return getWaitStatesSince(IsHazard, MI->getParent(),
-                            std::next(MI->getReverseIterator()),
-                            0, IsExpired, Visited);
+                            std::next(MI->getReverseIterator()), 0, IsExpired);
 }
 
 int GCNHazardRecognizer::getWaitStatesSince(IsHazardFn IsHazard, int Limit) {
@@ -1524,10 +1578,9 @@ bool GCNHazardRecognizer::fixLdsDirectVALUHazard(MachineInstr *MI) {
     return SIInstrInfo::isVALU(MI) ? 1 : 0;
   };
 
-  DenseSet<const MachineBasicBlock *> Visited;
   auto Count = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
                                     std::next(MI->getReverseIterator()), 0,
-                                    IsExpiredFn, Visited, GetWaitStatesFn);
+                                    IsExpiredFn, GetWaitStatesFn);
 
   // Transcendentals can execute in parallel to other VALUs.
   // This makes va_vdst count unusable with a mixture of VALU and TRANS.
@@ -3234,10 +3287,9 @@ bool GCNHazardRecognizer::fixVALUReadSGPRHazard(MachineInstr *MI) {
   };
 
   // Check for the hazard.
-  DenseSet<const MachineBasicBlock *> Visited;
   int WaitStates = ::getWaitStatesSince(IsHazardFn, MI->getParent(),
                                         std::next(MI->getReverseIterator()), 0,
-                                        IsExpiredFn, Visited, WaitStatesFn);
+                                        IsExpiredFn, WaitStatesFn);
 
   if (WaitStates >= SALUExpiryCount)
     return false;

@jayfoad
Copy link
Contributor

jayfoad commented Sep 12, 2024

+1 for the idea of this patch. It will take me a while to review the details.

It would be great to see a test case that demonstrates the shortest-path behaviour. I think this is technically a bug fix, right?

Incidentally there is another form of recursion in fixLdsBranchVmemWARHazard where the HazardFn calls back into getWaitStatesSince. But fixing that would require a completely different kind of refactor.

Copy link
Contributor

@jayfoad jayfoad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic looks good to me. I just added some nits and musings inline.

}

// Implements predecessor search for getWaitStatesSince.
static int getWaitStatesSinceImpl(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason for this to be split out into a separate function, now that it is not recursive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Allow renaming of some of the variables to avoid programming errors (on my part).
  • Hint the compiler that this is the "out of line" case.

if (!Visited.contains(Pred) || WaitStates < Visited[Pred]) {
// Store lowest wait states required to visit this block.
Visited[Pred] = WaitStates;
Worklist.push(Pred);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC: in the case where you just want to update Visited[Pred] to a new lower WaitStates value, this might add a second copy of Pred to the worklist, which would be wasteful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm aware of the duplication possibility, but every test I have done suggests the general overhead of adding a set or similar to avoid duplicate visits adds unnecessary cost.
I agree an anti-patterns that could exist that represent a worst case, but in practice I don't think this is an issue.

IsExpiredFn IsExpired,
GetNumWaitStatesFn GetNumWaitStates = SIInstrInfo::getNumWaitStates) {
// Scan this block from I.
auto InitSearch = countWaitStatesSince(IsHazard, MBB, I, WaitStates,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than call countWaitStatesSince here, can't you immediately call into getWaitStatesSinceImpl but initialize the worklist to just MBB? I.e. common up this function with the body of the loop in getWaitStatesSinceImpl? I guess the complexity is that each item on the worklist would have to be an (MBB, iterator) pair, so that you know at what point within each MBB to start searching.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used this approach in an early draft, but it complicates the logic and adds memory usage for no real benefit.

// Build worklist of predecessors.
// Note: use queue so search is breadth first, which reduces search space
// when a hazard is found.
std::queue<const MachineBasicBlock *> Worklist;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be more efficient to use SmallVector Worklist and Worklist.push_back() for inserting and Worklist[Idx++] for popping the front of the queue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, testing agrees using SmallVector is a bit faster (probably more so in common cases).
However, this method does mean the worklist will essentially grow to compass the entire search.
I have added some code to reclaim space when the list is (very) large.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

However, this method does mean the worklist will essentially grow to compass the entire search.
I have added some code to reclaim space when the list is (very) large.

I don't see why that is a problem at all. The absolute worst that can happen is we get a vector with one pointer per BB in the function. I don't think it's worth worrying about that amount of memory usage.

@perlfu
Copy link
Contributor Author

perlfu commented Sep 15, 2024

The logic looks good to me. I just added some nits and musings inline.

Thanks.
I still need to add a test.

A follow up is to flatten hasHazard and convert some of the current uses of getWaitStatesSince to hasHazard.
This is because they don't actually need the wait state count, so shouldn't be incurring the cost of this search.

@perlfu perlfu closed this Oct 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants