Commit 305fb4f

[AMDGPU] Add structural stall heuristic to scheduling strategies
Implements a structural stall heuristic that considers both resource hazards and latency constraints when selecting instructions from the pending queue.

- Add getStructuralStallCycles() to GCNSchedStrategy that computes the number of cycles an instruction must wait due to:
  - Resource conflicts on unbuffered resources (from the SchedModel)
  - Sequence-dependent hazards (from GCNHazardRecognizer)
- Add getHazardWaitStates() to GCNHazardRecognizer that returns the number of wait states until all hazards for an instruction are resolved, providing cycle-accurate hazard information for scheduling heuristics.
1 parent 4d99b6d commit 305fb4f

7 files changed: +131, -5 lines
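
The heuristic's core comparison is "effective stall = max(structural stall, latency stall)" per candidate. A minimal standalone C++ sketch of that rule, with invented cycle counts for the WMMA-versus-load scenario the commit targets (instruction names and numbers are illustrative assumptions, not scheduler output):

#include <algorithm>
#include <cstdio>

int main() {
  // Invented candidates at the current boundary cycle: a pending WMMA
  // blocked 8 cycles by a busy matrix pipeline, and an available
  // V_PK_ADD whose load operand arrives in 40 cycles.
  unsigned WmmaStructStall = 8, WmmaLatencyStall = 0;
  unsigned AddStructStall = 0, AddLatencyStall = 40;

  // Effective stall = max(structural, latency), as in tryPendingCandidate.
  unsigned WmmaEffective = std::max(WmmaStructStall, WmmaLatencyStall); // 8
  unsigned AddEffective = std::max(AddStructStall, AddLatencyStall);    // 40

  // The smaller effective stall wins: the WMMA issues first and the
  // memory latency is hidden behind its execution.
  std::printf("issue %s first\n",
              WmmaEffective < AddEffective ? "WMMA" : "V_PK_ADD");
  return 0;
}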

llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp

Lines changed: 72 additions & 0 deletions
@@ -13,6 +13,10 @@
 
 #include "AMDGPUMLSchedStrategy.h"
 
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "machine-scheduler"
+
 using namespace llvm;
 
 AMDGPUMLSchedStrategy::AMDGPUMLSchedStrategy(const MachineSchedContext *C)
@@ -130,6 +134,74 @@ bool AMDGPUMLSchedStrategy::tryCandidate(SchedCandidate &Cand,
   return false;
 }
 
+bool AMDGPUMLSchedStrategy::tryPendingCandidate(SchedCandidate &Cand,
+                                                SchedCandidate &TryCand,
+                                                SchedBoundary *Zone) const {
+  // Initialize the candidate if needed.
+  if (!Cand.isValid()) {
+    TryCand.Reason = NodeOrder;
+    return true;
+  }
+
+  // Bias PhysReg Defs and copies to their uses and defined respectively.
+  if (tryGreater(biasPhysReg(TryCand.SU, TryCand.AtTop),
+                 biasPhysReg(Cand.SU, Cand.AtTop), TryCand, Cand, PhysReg))
+    return TryCand.Reason != NoCand;
+
+  // Avoid exceeding the target's limit.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.Excess, Cand.RPDelta.Excess, TryCand, Cand,
+                  RegExcess, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  // Avoid increasing the max critical pressure in the scheduled region.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.CriticalMax, Cand.RPDelta.CriticalMax,
+                  TryCand, Cand, RegCritical, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  bool SameBoundary = Zone != nullptr;
+  if (SameBoundary) {
+    // Compare effective stall cycles between candidates.
+    // Effective stall = max(structural stall, latency stall)
+    // - Structural stalls: resource/hazard constraints (HW not ready)
+    // - Latency stalls: data dependency constraints (operands not ready)
+    //
+    // This allows picking a pending instruction with structural stalls over
+    // an available instruction with higher latency stalls (e.g., scheduling
+    // a WMMA while waiting for a memory load result).
+    unsigned TryStructStall = getStructuralStallCycles(*Zone, TryCand.SU);
+    unsigned TryLatencyStall = Zone->getLatencyStallCycles(TryCand.SU);
+    unsigned TryEffectiveStall = std::max(TryStructStall, TryLatencyStall);
+
+    unsigned CandStructStall = getStructuralStallCycles(*Zone, Cand.SU);
+    unsigned CandLatencyStall = Zone->getLatencyStallCycles(Cand.SU);
+    unsigned CandEffectiveStall = std::max(CandStructStall, CandLatencyStall);
+
+    LLVM_DEBUG(if (TryEffectiveStall || CandEffectiveStall) {
+      dbgs() << "Effective stalls: try=" << TryEffectiveStall
+             << " (struct=" << TryStructStall << ", lat=" << TryLatencyStall
+             << ") cand=" << CandEffectiveStall
+             << " (struct=" << CandStructStall << ", lat=" << CandLatencyStall
+             << ")\n";
+    });
+
+    if (tryLess(TryEffectiveStall, CandEffectiveStall, TryCand, Cand, Stall))
+      return TryCand.Reason != NoCand;
+
+    TryCand.initResourceDelta(DAG, SchedModel);
+    if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
+                TryCand, Cand, ResourceReduce))
+      return TryCand.Reason != NoCand;
+    if (tryGreater(TryCand.ResDelta.DemandedResources,
+                   Cand.ResDelta.DemandedResources, TryCand, Cand,
+                   ResourceDemand))
+      return TryCand.Reason != NoCand;
+  }
+
+  return false;
+}
+
 AMDGPUMLPostSchedStrategy::AMDGPUMLPostSchedStrategy(
     const MachineSchedContext *C)
     : PostGenericScheduler(C) {}
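
Every check in tryPendingCandidate funnels through the generic scheduler's tryLess/tryGreater helpers. As a hedged restatement of that convention (a simplified sketch of the helpers in llvm/lib/CodeGen/MachineScheduler.cpp, not code from this commit): the winner at a heuristic level gets its reason recorded, and false means a tie, so evaluation falls through to the next heuristic.

#include "llvm/CodeGen/MachineScheduler.h"
using namespace llvm;
using SchedCandidate = GenericSchedulerBase::SchedCandidate;

// Simplified sketch of the tryLess convention (illustrative only).
static bool tryLessSketch(unsigned TryVal, unsigned CandVal,
                          SchedCandidate &TryCand, SchedCandidate &Cand,
                          GenericSchedulerBase::CandReason Reason) {
  if (TryVal < CandVal) {
    TryCand.Reason = Reason; // TryCand wins at this level.
    return true;
  }
  if (TryVal > CandVal) {
    if (Cand.Reason > Reason)
      Cand.Reason = Reason;  // Cand wins; keep the stronger reason.
    return true;
  }
  return false;              // Tie: defer to the next heuristic.
}

This is why each call site returns TryCand.Reason != NoCand once a helper reports a decision: true means TryCand won, false means Cand did.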

llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ class AMDGPUMLSchedStrategy final : public GCNSchedStrategy {
   bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
                     SchedBoundary *Zone) const override;
 
+  bool tryPendingCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
+                           SchedBoundary *Zone) const override;
+
 public:
   AMDGPUMLSchedStrategy(const MachineSchedContext *C);

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 4 additions & 0 deletions
@@ -313,6 +313,10 @@ unsigned GCNHazardRecognizer::PreEmitNoops(MachineInstr *MI) {
   return std::max(W, NopPadding.getValue());
 }
 
+unsigned GCNHazardRecognizer::getHazardWaitStates(MachineInstr *MI) const {
+  return const_cast<GCNHazardRecognizer *>(this)->PreEmitNoopsCommon(MI);
+}
+
 unsigned GCNHazardRecognizer::PreEmitNoopsCommon(MachineInstr *MI) {
   if (MI->isBundle())
     return 0;

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.h

Lines changed: 6 additions & 0 deletions
@@ -145,6 +145,12 @@ class GCNHazardRecognizer final : public ScheduleHazardRecognizer {
   void EmitInstruction(SUnit *SU) override;
   void EmitInstruction(MachineInstr *MI) override;
   HazardType getHazardType(SUnit *SU, int Stalls) override;
+
+  /// Returns the number of wait states until all hazards for \p MI are
+  /// resolved. This is useful for scheduling heuristics that want
+  /// cycle-accurate hazard information rather than just a boolean. Unlike
+  /// PreEmitNoops, this does not modify state or fix hazards.
+  unsigned getHazardWaitStates(MachineInstr *MI) const;
   void EmitNoop() override;
   unsigned PreEmitNoops(MachineInstr *) override;
   unsigned PreEmitNoopsCommon(MachineInstr *);
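
The doc comment draws the key contrast: PreEmitNoops participates in hazard fixing at emission time, while the new accessor is a read-only probe that a heuristic can apply per candidate. A hedged usage sketch (HR and MI stand for a recognizer and an instruction already in scope; illustrative, not from this commit):

// Side-effect-free query: safe to call repeatedly while comparing
// candidates. Zero means no hazard-induced stall.
unsigned WaitStates = HR.getHazardWaitStates(MI);

// Emission path: may fix hazards and reports required nop padding.
unsigned Nops = HR.PreEmitNoops(MI);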

llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp

Lines changed: 35 additions & 0 deletions
@@ -25,6 +25,7 @@
 
 #include "GCNSchedStrategy.h"
 #include "AMDGPUIGroupLP.h"
+#include "GCNHazardRecognizer.h"
 #include "GCNRegPressure.h"
 #include "SIMachineFunctionInfo.h"
 #include "Utils/AMDGPUBaseInfo.h"
@@ -218,6 +219,40 @@ void GCNSchedStrategy::getRegisterPressures(
   Pressure[AMDGPU::RegisterPressureSets::AGPR_32] = NewPressure.getAGPRNum();
 }
 
+unsigned GCNSchedStrategy::getStructuralStallCycles(SchedBoundary &Zone,
+                                                    SUnit *SU) const {
+  // Only implemented for top-down scheduling currently.
+  if (!Zone.isTop() || !SU)
+    return 0;
+
+  MachineInstr *MI = SU->getInstr();
+  unsigned CurrCycle = Zone.getCurrCycle();
+  unsigned Stall = 0;
+
+  // Query SchedModel for resource stalls (unbuffered resources).
+  if (SchedModel->hasInstrSchedModel() && SU->hasReservedResource) {
+    const MCSchedClassDesc *SC = DAG->getSchedClass(SU);
+    for (const MCWriteProcResEntry &PE :
+         make_range(SchedModel->getWriteProcResBegin(SC),
+                    SchedModel->getWriteProcResEnd(SC))) {
+      unsigned NextAvail =
+          Zone.getNextResourceCycle(SC, PE.ProcResourceIdx, PE.ReleaseAtCycle,
+                                    PE.AcquireAtCycle)
+              .first;
+      if (NextAvail > CurrCycle)
+        Stall = std::max(Stall, NextAvail - CurrCycle);
+    }
+  }
+
+  // Query HazardRecognizer for sequence-dependent hazard penalties.
+  if (Zone.HazardRec && Zone.HazardRec->isEnabled()) {
+    auto *HR = static_cast<GCNHazardRecognizer *>(Zone.HazardRec);
+    Stall = std::max(Stall, HR->getHazardWaitStates(MI));
+  }
+
+  return Stall;
+}
+
 void GCNSchedStrategy::initCandidate(SchedCandidate &Cand, SUnit *SU,
                                      bool AtTop,
                                      const RegPressureTracker &RPTracker,
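
The resource half of getStructuralStallCycles takes the worst case over every processor resource the instruction's sched class writes: each still-busy resource contributes the gap between its next-available cycle and the current cycle. A standalone sketch of that reduction with invented ready cycles (getNextResourceCycle supplies the real values; the helper name here is an assumption for illustration):

#include <algorithm>

// Worst-case stall over resources: NextAvail[I] is the cycle at which
// resource I next becomes free.
unsigned resourceStall(unsigned CurrCycle, const unsigned *NextAvail,
                       unsigned NumRes) {
  unsigned Stall = 0;
  for (unsigned I = 0; I < NumRes; ++I)
    if (NextAvail[I] > CurrCycle)
      Stall = std::max(Stall, NextAvail[I] - CurrCycle);
  return Stall;
}

// At cycle 10 with resources next free at cycles 14 and 9, the stall is 4:
// only the still-busy resource contributes, and the busiest one dominates.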

llvm/lib/Target/AMDGPU/GCNSchedStrategy.h

Lines changed: 6 additions & 2 deletions
@@ -56,6 +56,10 @@ class GCNSchedStrategy : public GenericScheduler {
                        const SIRegisterInfo *SRI, unsigned SGPRPressure,
                        unsigned VGPRPressure, bool IsBottomUp);
 
+  /// Estimate how many cycles \p SU must wait due to structural hazards at the
+  /// current boundary cycle. Returns zero when no stall is required.
+  unsigned getStructuralStallCycles(SchedBoundary &Zone, SUnit *SU) const;
+
   /// Evaluates instructions in the pending queue using a subset of scheduling
   /// heuristics.
   ///
@@ -64,8 +68,8 @@ class GCNSchedStrategy : public GenericScheduler {
   /// invisible to scheduling heuristics. However, in certain scenarios (such as
   /// avoiding register spilling), it may be beneficial to consider scheduling
   /// these not-yet-ready instructions.
-  bool tryPendingCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
-                           SchedBoundary *Zone) const;
+  virtual bool tryPendingCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
+                                   SchedBoundary *Zone) const;
 
   void printCandidateDecision(const SchedCandidate &Current,
                               const SchedCandidate &Preferred);
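
Making tryPendingCandidate virtual is what lets AMDGPUMLSchedStrategy install the stall-aware version above while other GCN strategies inherit the base behavior. A hypothetical further override (MyStrategy is illustrative and not part of this commit; assumes the GCNSchedStrategy.h declarations are in scope):

class MyStrategy final : public GCNSchedStrategy {
public:
  MyStrategy(const MachineSchedContext *C) : GCNSchedStrategy(C) {}

  bool tryPendingCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
                           SchedBoundary *Zone) const override {
    // Defer to the shared heuristics first; true means TryCand won.
    if (GCNSchedStrategy::tryPendingCandidate(Cand, TryCand, Zone))
      return true;
    // ...additional target-specific tie-breakers would go here...
    return false;
  }
};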

llvm/test/CodeGen/AMDGPU/ml-sched-effective-stall.mir

Lines changed: 5 additions & 3 deletions
@@ -12,6 +12,8 @@
 attributes #1 = { "amdgpu-waves-per-eu"="1,1" }
 ...
 
+# The scheduler should reorder the use of the global load after WMMAs to hide memory latency.
+
 ---
 name: with_ml_workload_attr
 tracksRegLiveness: true
@@ -31,8 +33,8 @@ body: |
 ; DEFAULT-NEXT: [[DEF10:%[0-9]+]]:vreg_64_align2 = IMPLICIT_DEF
 ; DEFAULT-NEXT: [[GLOBAL_LOAD_DWORDX2_:%[0-9]+]]:vreg_64_align2 = GLOBAL_LOAD_DWORDX2 [[DEF10]], 0, 0, implicit $exec
 ; DEFAULT-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
-; DEFAULT-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
 ; DEFAULT-NEXT: early-clobber %14:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF5]], [[DEF6]], 0, [[DEF7]], [[DEF8]], [[DEF9]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
+; DEFAULT-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
 ; DEFAULT-NEXT: S_ENDPGM 0, implicit [[V_PK_ADD_F32_]], implicit %13, implicit %14
 ;
 ; ML-LABEL: name: with_ml_workload_attr
@@ -49,8 +51,8 @@ body: |
 ; ML-NEXT: [[DEF10:%[0-9]+]]:vreg_64_align2 = IMPLICIT_DEF
 ; ML-NEXT: [[GLOBAL_LOAD_DWORDX2_:%[0-9]+]]:vreg_64_align2 = GLOBAL_LOAD_DWORDX2 [[DEF10]], 0, 0, implicit $exec
 ; ML-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
-; ML-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
 ; ML-NEXT: early-clobber %14:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF5]], [[DEF6]], 0, [[DEF7]], [[DEF8]], [[DEF9]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
+; ML-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
 ; ML-NEXT: S_ENDPGM 0, implicit [[V_PK_ADD_F32_]], implicit %13, implicit %14
 %0:vreg_512_align2 = IMPLICIT_DEF
 %1:vreg_512_align2 = IMPLICIT_DEF
@@ -108,8 +110,8 @@ body: |
 ; ML-NEXT: [[DEF10:%[0-9]+]]:vreg_64_align2 = IMPLICIT_DEF
 ; ML-NEXT: [[GLOBAL_LOAD_DWORDX2_:%[0-9]+]]:vreg_64_align2 = GLOBAL_LOAD_DWORDX2 [[DEF10]], 0, 0, implicit $exec
 ; ML-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
-; ML-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
 ; ML-NEXT: early-clobber %14:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF5]], [[DEF6]], 0, [[DEF7]], [[DEF8]], [[DEF9]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
+; ML-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
 ; ML-NEXT: S_ENDPGM 0, implicit [[V_PK_ADD_F32_]], implicit %13, implicit %14
 %0:vreg_512_align2 = IMPLICIT_DEF
 %1:vreg_512_align2 = IMPLICIT_DEF
