Commit bbf69d1
[AMDGPU] Add structural stall heuristic to scheduling strategies
Implements a structural stall heuristic that considers both resource hazards and latency constraints when selecting instructions from the pending queue.

- Add getStructuralStallCycles() to GCNSchedStrategy, which computes the number of cycles an instruction must wait due to:
  - resource conflicts on unbuffered resources (from the SchedModel), and
  - sequence-dependent hazards (from GCNHazardRecognizer).
- Add getHazardWaitStates() to GCNHazardRecognizer, which returns the number of wait states until all hazards for an instruction are resolved, providing cycle-accurate hazard information for scheduling heuristics.
1 parent b639aa2 commit bbf69d1
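
The core of the heuristic boils down to comparing each candidate's effective stall. A minimal standalone sketch of that comparison (illustrative only; StallInfo, effectiveStall, and preferTry are made-up names, not code from this commit):

#include <algorithm>

struct StallInfo {
  unsigned StructStall;  // cycles until resources/hazards clear (HW not ready)
  unsigned LatencyStall; // cycles until operands are ready (data not ready)
};

// An instruction can issue only once both constraints are satisfied.
unsigned effectiveStall(const StallInfo &S) {
  return std::max(S.StructStall, S.LatencyStall);
}

// Prefer the candidate that can issue sooner. For example, a pending WMMA
// with an 8-cycle structural stall (max(8, 0) == 8) beats an available add
// waiting 40 cycles on a global load result (max(0, 40) == 40).
bool preferTry(const StallInfo &Try, const StallInfo &Cand) {
  return effectiveStall(Try) < effectiveStall(Cand);
}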

File tree: 7 files changed (+136, −6 lines)


llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp

Lines changed: 72 additions & 0 deletions
@@ -13,6 +13,10 @@
 
 #include "AMDGPUMLSchedStrategy.h"
 
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "machine-scheduler"
+
 using namespace llvm;
 
 AMDGPUMLSchedStrategy::AMDGPUMLSchedStrategy(const MachineSchedContext *C)
@@ -121,6 +125,74 @@ bool AMDGPUMLSchedStrategy::tryCandidate(SchedCandidate &Cand,
   return false;
 }
 
+bool AMDGPUMLSchedStrategy::tryPendingCandidate(SchedCandidate &Cand,
+                                                SchedCandidate &TryCand,
+                                                SchedBoundary *Zone) const {
+  // Initialize the candidate if needed.
+  if (!Cand.isValid()) {
+    TryCand.Reason = NodeOrder;
+    return true;
+  }
+
+  // Bias PhysReg Defs and copies to their uses and defined respectively.
+  if (tryGreater(biasPhysReg(TryCand.SU, TryCand.AtTop),
+                 biasPhysReg(Cand.SU, Cand.AtTop), TryCand, Cand, PhysReg))
+    return TryCand.Reason != NoCand;
+
+  // Avoid exceeding the target's limit.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.Excess, Cand.RPDelta.Excess, TryCand, Cand,
+                  RegExcess, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  // Avoid increasing the max critical pressure in the scheduled region.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.CriticalMax, Cand.RPDelta.CriticalMax,
+                  TryCand, Cand, RegCritical, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  bool SameBoundary = Zone != nullptr;
+  if (SameBoundary) {
+    // Compare effective stall cycles between candidates.
+    // Effective stall = max(structural stall, latency stall)
+    //  - Structural stalls: resource/hazard constraints (HW not ready)
+    //  - Latency stalls: data dependency constraints (operands not ready)
+    //
+    // This allows picking a pending instruction with structural stalls over
+    // an available instruction with higher latency stalls (e.g., scheduling
+    // a WMMA while waiting for a memory load result).
+    unsigned TryStructStall = getStructuralStallCycles(*Zone, TryCand.SU);
+    unsigned TryLatencyStall = Zone->getLatencyStallCycles(TryCand.SU);
+    unsigned TryEffectiveStall = std::max(TryStructStall, TryLatencyStall);
+
+    unsigned CandStructStall = getStructuralStallCycles(*Zone, Cand.SU);
+    unsigned CandLatencyStall = Zone->getLatencyStallCycles(Cand.SU);
+    unsigned CandEffectiveStall = std::max(CandStructStall, CandLatencyStall);
+
+    LLVM_DEBUG(if (TryEffectiveStall || CandEffectiveStall) {
+      dbgs() << "Effective stalls: try=" << TryEffectiveStall
+             << " (struct=" << TryStructStall << ", lat=" << TryLatencyStall
+             << ") cand=" << CandEffectiveStall
+             << " (struct=" << CandStructStall << ", lat=" << CandLatencyStall
+             << ")\n";
+    });
+
+    if (tryLess(TryEffectiveStall, CandEffectiveStall, TryCand, Cand, Stall))
+      return TryCand.Reason != NoCand;
+
+    TryCand.initResourceDelta(DAG, SchedModel);
+    if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
+                TryCand, Cand, ResourceReduce))
+      return TryCand.Reason != NoCand;
+    if (tryGreater(TryCand.ResDelta.DemandedResources,
+                   Cand.ResDelta.DemandedResources, TryCand, Cand,
+                   ResourceDemand))
+      return TryCand.Reason != NoCand;
+  }
+
+  return false;
+}
+
 AMDGPUMLPostSchedStrategy::AMDGPUMLPostSchedStrategy(
     const MachineSchedContext *C)
     : PostGenericScheduler(C) {}
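
To observe the LLVM_DEBUG trace added above, an assertions-enabled build can enable the machine-scheduler debug channel when running the new MIR test; a typical invocation (shown as an example, not taken from this commit) looks like:

  llc -run-pass=machine-scheduler -debug-only=machine-scheduler ml-sched-effective-stall.mir -o /dev/null

which prints lines like "Effective stalls: try=8 (struct=8, lat=0) cand=40 (struct=0, lat=40)" whenever either candidate would stall.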

llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,8 @@ class AMDGPUMLSchedStrategy final : public GCNSchedStrategy {
 protected:
   bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
                     SchedBoundary *Zone) const override;
+  bool tryPendingCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
+                           SchedBoundary *Zone) const override;
 
 public:
   AMDGPUMLSchedStrategy(const MachineSchedContext *C);
@@ -33,4 +35,4 @@ class AMDGPUMLPostSchedStrategy : public PostGenericScheduler {
   AMDGPUMLPostSchedStrategy(const MachineSchedContext *C);
 };
 
-} // End namespace llvm
+} // End namespace llvm

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 4 additions & 0 deletions
@@ -313,6 +313,10 @@ unsigned GCNHazardRecognizer::PreEmitNoops(MachineInstr *MI) {
   return std::max(W, NopPadding.getValue());
 }
 
+unsigned GCNHazardRecognizer::getHazardWaitStates(MachineInstr *MI) const {
+  return const_cast<GCNHazardRecognizer *>(this)->PreEmitNoopsCommon(MI);
+}
+
 unsigned GCNHazardRecognizer::PreEmitNoopsCommon(MachineInstr *MI) {
   if (MI->isBundle())
     return 0;

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.h

Lines changed: 6 additions & 0 deletions
@@ -145,6 +145,12 @@ class GCNHazardRecognizer final : public ScheduleHazardRecognizer {
   void EmitInstruction(SUnit *SU) override;
   void EmitInstruction(MachineInstr *MI) override;
   HazardType getHazardType(SUnit *SU, int Stalls) override;
+
+  /// Returns the number of wait states until all hazards for \p MI are
+  /// resolved. This is useful for scheduling heuristics that want
+  /// cycle-accurate hazard information rather than just a boolean. Unlike
+  /// PreEmitNoops, this does not modify state or fix hazards.
+  unsigned getHazardWaitStates(MachineInstr *MI) const;
   void EmitNoop() override;
   unsigned PreEmitNoops(MachineInstr *) override;
   unsigned PreEmitNoopsCommon(MachineInstr *);
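
The distinction the doc comment draws is between a yes/no answer and a cost. A minimal sketch of how a heuristic might consume the new query (hazardCostForHeuristic is a hypothetical helper, not part of this commit):

// getHazardType only says whether SU hazards *right now*; getHazardWaitStates
// says how many wait states remain, letting a heuristic rank two stalled
// candidates instead of rejecting both.
unsigned hazardCostForHeuristic(GCNHazardRecognizer &HR, SUnit *SU) {
  MachineInstr *MI = SU->getInstr();
  return MI ? HR.getHazardWaitStates(MI) : 0; // 0 => can issue immediately
}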

llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp

Lines changed: 40 additions & 0 deletions
@@ -25,6 +25,7 @@
 
 #include "GCNSchedStrategy.h"
 #include "AMDGPUIGroupLP.h"
+#include "GCNHazardRecognizer.h"
 #include "GCNRegPressure.h"
 #include "SIMachineFunctionInfo.h"
 #include "Utils/AMDGPUBaseInfo.h"
@@ -218,6 +219,40 @@ void GCNSchedStrategy::getRegisterPressures(
   Pressure[AMDGPU::RegisterPressureSets::AGPR_32] = NewPressure.getAGPRNum();
 }
 
+unsigned GCNSchedStrategy::getStructuralStallCycles(SchedBoundary &Zone,
+                                                    SUnit *SU) const {
+  // Only implemented for top-down scheduling currently.
+  if (!Zone.isTop() || !SU)
+    return 0;
+
+  MachineInstr *MI = SU->getInstr();
+  unsigned CurrCycle = Zone.getCurrCycle();
+  unsigned Stall = 0;
+
+  // Query SchedModel for resource stalls (unbuffered resources).
+  if (SchedModel->hasInstrSchedModel() && SU->hasReservedResource) {
+    const MCSchedClassDesc *SC = DAG->getSchedClass(SU);
+    for (const MCWriteProcResEntry &PE :
+         make_range(SchedModel->getWriteProcResBegin(SC),
+                    SchedModel->getWriteProcResEnd(SC))) {
+      unsigned NextAvail =
+          Zone.getNextResourceCycle(SC, PE.ProcResourceIdx, PE.ReleaseAtCycle,
+                                    PE.AcquireAtCycle)
+              .first;
+      if (NextAvail > CurrCycle)
+        Stall = std::max(Stall, NextAvail - CurrCycle);
+    }
+  }
+
+  // Query HazardRecognizer for sequence-dependent hazard penalties.
+  if (Zone.HazardRec && Zone.HazardRec->isEnabled()) {
+    auto *HR = static_cast<GCNHazardRecognizer *>(Zone.HazardRec);
+    Stall = std::max(Stall, HR->getHazardWaitStates(MI));
+  }
+
+  return Stall;
+}
+
 void GCNSchedStrategy::initCandidate(SchedCandidate &Cand, SUnit *SU,
                                      bool AtTop,
                                      const RegPressureTracker &RPTracker,
@@ -673,6 +708,11 @@ bool GCNSchedStrategy::tryPendingCandidate(SchedCandidate &Cand,
 
   bool SameBoundary = Zone != nullptr;
   if (SameBoundary) {
+    unsigned TryStructStall = getStructuralStallCycles(*Zone, TryCand.SU);
+    unsigned CandStructStall = getStructuralStallCycles(*Zone, Cand.SU);
+    if (tryLess(TryStructStall, CandStructStall, TryCand, Cand, Stall))
+      return TryCand.Reason != NoCand;
+
     TryCand.initResourceDelta(DAG, SchedModel);
     if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
                 TryCand, Cand, ResourceReduce))
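
The resource half of getStructuralStallCycles reduces to simple cycle arithmetic. An illustrative restatement of the per-resource computation (resourceStall is a made-up name, not part of the patch):

// If an unbuffered resource next frees at cycle 12 while the boundary sits at
// cycle 8, the instruction must stall 12 - 8 = 4 cycles; the overall
// structural stall is the max of this over all resources the instruction
// writes, and of the hazard recognizer's wait states.
unsigned resourceStall(unsigned NextAvailCycle, unsigned CurrCycle) {
  return NextAvailCycle > CurrCycle ? NextAvailCycle - CurrCycle : 0;
}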

llvm/lib/Target/AMDGPU/GCNSchedStrategy.h

Lines changed: 6 additions & 2 deletions
@@ -56,6 +56,10 @@ class GCNSchedStrategy : public GenericScheduler {
                      const SIRegisterInfo *SRI, unsigned SGPRPressure,
                      unsigned VGPRPressure, bool IsBottomUp);
 
+  /// Estimate how many cycles \p SU must wait due to structural hazards at the
+  /// current boundary cycle. Returns zero when no stall is required.
+  unsigned getStructuralStallCycles(SchedBoundary &Zone, SUnit *SU) const;
+
   /// Evaluates instructions in the pending queue using a subset of scheduling
   /// heuristics.
   ///
@@ -64,8 +68,8 @@ class GCNSchedStrategy : public GenericScheduler {
   /// invisible to scheduling heuristics. However, in certain scenarios (such as
   /// avoiding register spilling), it may be beneficial to consider scheduling
   /// these not-yet-ready instructions.
-  bool tryPendingCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
-                           SchedBoundary *Zone) const;
+  virtual bool tryPendingCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
+                                   SchedBoundary *Zone) const;
 
   void printCandidateDecision(const SchedCandidate &Current,
                               const SchedCandidate &Preferred);

llvm/test/CodeGen/AMDGPU/ml-sched-effective-stall.mir

Lines changed: 5 additions & 3 deletions
@@ -10,6 +10,8 @@
 attributes #1 = { "amdgpu-waves-per-eu"="1,1" }
 ...
 
+# The scheduler should reorder the use of the global load after WMMAs to hide memory latency.
+
 ---
 name: with_ml_workload_attr
 tracksRegLiveness: true
@@ -29,8 +31,8 @@ body: |
   ; DEFAULT-NEXT: [[DEF10:%[0-9]+]]:vreg_64_align2 = IMPLICIT_DEF
   ; DEFAULT-NEXT: [[GLOBAL_LOAD_DWORDX2_:%[0-9]+]]:vreg_64_align2 = GLOBAL_LOAD_DWORDX2 [[DEF10]], 0, 0, implicit $exec
   ; DEFAULT-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
-  ; DEFAULT-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
   ; DEFAULT-NEXT: early-clobber %14:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF5]], [[DEF6]], 0, [[DEF7]], [[DEF8]], [[DEF9]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
+  ; DEFAULT-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
   ; DEFAULT-NEXT: S_ENDPGM 0, implicit [[V_PK_ADD_F32_]], implicit %13, implicit %14
   ;
   ; ML-LABEL: name: with_ml_workload_attr
@@ -47,8 +49,8 @@ body: |
   ; ML-NEXT: [[DEF10:%[0-9]+]]:vreg_64_align2 = IMPLICIT_DEF
   ; ML-NEXT: [[GLOBAL_LOAD_DWORDX2_:%[0-9]+]]:vreg_64_align2 = GLOBAL_LOAD_DWORDX2 [[DEF10]], 0, 0, implicit $exec
   ; ML-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
-  ; ML-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
   ; ML-NEXT: early-clobber %14:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF5]], [[DEF6]], 0, [[DEF7]], [[DEF8]], [[DEF9]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
+  ; ML-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
   ; ML-NEXT: S_ENDPGM 0, implicit [[V_PK_ADD_F32_]], implicit %13, implicit %14
   %0:vreg_512_align2 = IMPLICIT_DEF
   %1:vreg_512_align2 = IMPLICIT_DEF
@@ -99,8 +101,8 @@ body: |
   ; ML-NEXT: [[DEF3:%[0-9]+]]:vgpr_32_lo256 = IMPLICIT_DEF
   ; ML-NEXT: [[DEF4:%[0-9]+]]:vgpr_32_lo256 = IMPLICIT_DEF
   ; ML-NEXT: [[DEF5:%[0-9]+]]:vreg_64_align2 = IMPLICIT_DEF
-  ; ML-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
   ; ML-NEXT: [[GLOBAL_LOAD_DWORDX2_:%[0-9]+]]:vreg_64_align2 = GLOBAL_LOAD_DWORDX2 [[DEF5]], 0, 0, implicit $exec
+  ; ML-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
   ; ML-NEXT: [[DEF6:%[0-9]+]]:vreg_512_align2 = IMPLICIT_DEF
   ; ML-NEXT: [[DEF7:%[0-9]+]]:vreg_512_align2 = IMPLICIT_DEF
   ; ML-NEXT: [[DEF8:%[0-9]+]]:vreg_256_align2 = IMPLICIT_DEF
