Commit b639aa2

[AMDGPU] Add scaffolding for ML focused scheduling strategy
This patch introduces scaffolding for a new machine instruction scheduling strategy optimized for machine learning workloads. The ML scheduler is enabled automatically for functions that carry the "amdgpu-workload-type"="ml" attribute.
1 parent 21e0b56 commit b639aa2
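The ML scheduler is keyed off a string function attribute, so a frontend or IR-level pass can opt a function in. A minimal sketch of doing so through the C++ API follows; the helper name is hypothetical, and only Function::addFnAttr and the attribute string come from this patch:

// Illustrative sketch, not part of this commit: tag a function so that the
// isMLWorkload() check added in AMDGPUTargetMachine.cpp selects the ML
// scheduling strategy for it. The helper name is hypothetical.
#include "llvm/IR/Function.h"

static void markAsMLWorkload(llvm::Function &F) {
  // String-valued function attribute read in
  // GCNTargetMachine::createMachineScheduler().
  F.addFnAttr("amdgpu-workload-type", "ml");
}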

8 files changed: +395 additions, -13 deletions
llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+//===-- AMDGPUMLSchedStrategy.cpp - ML-focused Scheduler Strategy ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// ML-focused scheduling strategy for AMDGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPUMLSchedStrategy.h"
+
+using namespace llvm;
+
+AMDGPUMLSchedStrategy::AMDGPUMLSchedStrategy(const MachineSchedContext *C)
+    : GCNSchedStrategy(C) {
+  SchedStages.push_back(GCNSchedStageID::ILPInitialSchedule);
+  SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
+  // Use more accurate GCN pressure trackers.
+  UseGCNTrackers = true;
+}
+
+bool AMDGPUMLSchedStrategy::tryCandidate(SchedCandidate &Cand,
+                                         SchedCandidate &TryCand,
+                                         SchedBoundary *Zone) const {
+  // Initialize the candidate if needed.
+  if (!Cand.isValid()) {
+    TryCand.Reason = FirstValid;
+    return true;
+  }
+
+  // Bias PhysReg Defs and copies to their uses and defined respectively.
+  if (tryGreater(biasPhysReg(TryCand.SU, TryCand.AtTop),
+                 biasPhysReg(Cand.SU, Cand.AtTop), TryCand, Cand, PhysReg))
+    return TryCand.Reason != NoCand;
+
+  // Avoid exceeding the target's limit.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.Excess, Cand.RPDelta.Excess, TryCand, Cand,
+                  RegExcess, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  // We only compare a subset of features when comparing nodes between
+  // Top and Bottom boundary. Some properties are simply incomparable, in many
+  // other instances we should only override the other boundary if something
+  // is a clear good pick on one boundary. Skip heuristics that are more
+  // "tie-breaking" in nature.
+  bool SameBoundary = Zone != nullptr;
+  if (SameBoundary) {
+    // For loops that are acyclic path limited, aggressively schedule for
+    // latency. Within an single cycle, whenever CurrMOps > 0, allow normal
+    // heuristics to take precedence.
+    if (Rem.IsAcyclicLatencyLimited && !Zone->getCurrMOps() &&
+        tryLatency(TryCand, Cand, *Zone))
+      return TryCand.Reason != NoCand;
+
+    // Prioritize instructions that read unbuffered resources by stall cycles.
+    if (tryLess(Zone->getLatencyStallCycles(TryCand.SU),
+                Zone->getLatencyStallCycles(Cand.SU), TryCand, Cand, Stall))
+      return TryCand.Reason != NoCand;
+  }
+
+  // Keep clustered nodes together to encourage downstream peephole
+  // optimizations which may reduce resource requirements.
+  //
+  // This is a best effort to set things up for a post-RA pass. Optimizations
+  // like generating loads of multiple registers should ideally be done within
+  // the scheduler pass by combining the loads during DAG postprocessing.
+  unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID;
+  unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID;
+  bool CandIsClusterSucc =
+      isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx);
+  bool TryCandIsClusterSucc =
+      isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx);
+
+  if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand,
+                 Cluster))
+    return TryCand.Reason != NoCand;
+
+  if (SameBoundary) {
+    // Weak edges are for clustering and other constraints.
+    if (tryLess(getWeakLeft(TryCand.SU, TryCand.AtTop),
+                getWeakLeft(Cand.SU, Cand.AtTop), TryCand, Cand, Weak))
+      return TryCand.Reason != NoCand;
+  }
+
+  // Avoid increasing the max pressure of the entire region.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.CurrentMax, Cand.RPDelta.CurrentMax, TryCand,
+                  Cand, RegMax, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  if (SameBoundary) {
+    // Avoid critical resource consumption and balance the schedule.
+    TryCand.initResourceDelta(DAG, SchedModel);
+    if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
+                TryCand, Cand, ResourceReduce))
+      return TryCand.Reason != NoCand;
+    if (tryGreater(TryCand.ResDelta.DemandedResources,
+                   Cand.ResDelta.DemandedResources, TryCand, Cand,
+                   ResourceDemand))
+      return TryCand.Reason != NoCand;
+
+    // Avoid serializing long latency dependence chains.
+    // For acyclic path limited loops, latency was already checked above.
+    if (!RegionPolicy.DisableLatencyHeuristic && TryCand.Policy.ReduceLatency &&
+        !Rem.IsAcyclicLatencyLimited && tryLatency(TryCand, Cand, *Zone))
+      return TryCand.Reason != NoCand;
+
+    // Fall through to original instruction order.
+    if ((Zone->isTop() && TryCand.SU->NodeNum < Cand.SU->NodeNum) ||
+        (!Zone->isTop() && TryCand.SU->NodeNum > Cand.SU->NodeNum)) {
+      TryCand.Reason = NodeOrder;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+AMDGPUMLPostSchedStrategy::AMDGPUMLPostSchedStrategy(
+    const MachineSchedContext *C)
+    : PostGenericScheduler(C) {}
+
+bool AMDGPUMLPostSchedStrategy::tryCandidate(SchedCandidate &Cand,
+                                             SchedCandidate &TryCand) {
+  // Initialize the candidate if needed.
+  if (!Cand.isValid()) {
+    TryCand.Reason = FirstValid;
+    return true;
+  }
+
+  // Prioritize instructions that read unbuffered resources by stall cycles.
+  if (tryLess(Top.getLatencyStallCycles(TryCand.SU),
+              Top.getLatencyStallCycles(Cand.SU), TryCand, Cand, Stall))
+    return TryCand.Reason != NoCand;
+
+  // Keep clustered nodes together.
+  unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID;
+  unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID;
+  bool CandIsClusterSucc =
+      isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx);
+  bool TryCandIsClusterSucc =
+      isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx);
+
+  if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand,
+                 Cluster))
+    return TryCand.Reason != NoCand;
+  // Avoid critical resource consumption and balance the schedule.
+  if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
+              TryCand, Cand, ResourceReduce))
+    return TryCand.Reason != NoCand;
+  if (tryGreater(TryCand.ResDelta.DemandedResources,
+                 Cand.ResDelta.DemandedResources, TryCand, Cand,
+                 ResourceDemand))
+    return TryCand.Reason != NoCand;
+
+  // We only compare a subset of features when comparing nodes between
+  // Top and Bottom boundary.
+  if (Cand.AtTop == TryCand.AtTop) {
+    // Avoid serializing long latency dependence chains.
+    if (Cand.Policy.ReduceLatency &&
+        tryLatency(TryCand, Cand, Cand.AtTop ? Top : Bot))
+      return TryCand.Reason != NoCand;
+  }
+
+  // Fall through to original instruction order.
+  if (TryCand.SU->NodeNum < Cand.SU->NodeNum) {
+    TryCand.Reason = NodeOrder;
+    return true;
+  }
+
+  return false;
+}
llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+//===-- AMDGPUMLSchedStrategy.h - ML-focused Scheduler Strategy -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// ML-focused scheduling strategy for AMDGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#include "GCNSchedStrategy.h"
+#include "llvm/CodeGen/MachineScheduler.h"
+
+namespace llvm {
+
+class AMDGPUMLSchedStrategy final : public GCNSchedStrategy {
+protected:
+  bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
+                    SchedBoundary *Zone) const override;
+
+public:
+  AMDGPUMLSchedStrategy(const MachineSchedContext *C);
+};
+
+class AMDGPUMLPostSchedStrategy : public PostGenericScheduler {
+protected:
+  bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand) override;
+
+public:
+  AMDGPUMLPostSchedStrategy(const MachineSchedContext *C);
+};
+
+} // End namespace llvm

llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp

Lines changed: 25 additions & 3 deletions
@@ -43,6 +43,7 @@
 #include "GCNPreRAOptimizations.h"
 #include "GCNRewritePartialRegUses.h"
 #include "GCNSchedStrategy.h"
+#include "AMDGPUMLSchedStrategy.h"
 #include "GCNVOPDUtils.h"
 #include "R600.h"
 #include "R600TargetMachine.h"
@@ -636,6 +637,11 @@ static ScheduleDAGInstrs *createSIMachineScheduler(MachineSchedContext *C) {
   return new SIScheduleDAGMI(C);
 }
 
+static bool isMLWorkload(const Function &F) {
+  Attribute WorkloadAttr = F.getFnAttribute("amdgpu-workload-type");
+  return WorkloadAttr.isValid() && WorkloadAttr.getValueAsString() == "ml";
+}
+
 static ScheduleDAGInstrs *
 createGCNMaxOccupancyMachineScheduler(MachineSchedContext *C) {
   const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -659,6 +665,11 @@ createGCNMaxILPMachineScheduler(MachineSchedContext *C) {
   return DAG;
 }
 
+static ScheduleDAGInstrs *createGCNMLMachineScheduler(MachineSchedContext *C) {
+  return new GCNScheduleDAGMILive(C,
+                                  std::make_unique<AMDGPUMLSchedStrategy>(C));
+}
+
 static ScheduleDAGInstrs *
 createGCNMaxMemoryClauseMachineScheduler(MachineSchedContext *C) {
   const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -1170,6 +1181,9 @@ GCNTargetMachine::createMachineScheduler(MachineSchedContext *C) const {
   if (ST.enableSIScheduler())
     return createSIMachineScheduler(C);
 
+  if (isMLWorkload(C->MF->getFunction()))
+    return createGCNMLMachineScheduler(C);
+
   Attribute SchedStrategyAttr =
       C->MF->getFunction().getFnAttribute("amdgpu-sched-strategy");
   StringRef SchedStrategy = SchedStrategyAttr.isValid()
@@ -1191,14 +1205,22 @@ GCNTargetMachine::createMachineScheduler(MachineSchedContext *C) const {
   if (SchedStrategy == "iterative-maxocc")
     return createIterativeGCNMaxOccupancyMachineScheduler(C);
 
+  if (SchedStrategy == "ml")
+    return createGCNMLMachineScheduler(C);
+
   return createGCNMaxOccupancyMachineScheduler(C);
 }
 
 ScheduleDAGInstrs *
 GCNTargetMachine::createPostMachineScheduler(MachineSchedContext *C) const {
-  ScheduleDAGMI *DAG =
-      new GCNPostScheduleDAGMILive(C, std::make_unique<PostGenericScheduler>(C),
-                                   /*RemoveKillFlags=*/true);
+  if (isMLWorkload(C->MF->getFunction()))
+    return new GCNPostScheduleDAGMILive(
+        C, std::make_unique<AMDGPUMLPostSchedStrategy>(C),
+        /*RemoveKillFlags=*/true);
+
+  ScheduleDAGMI *DAG = new GCNPostScheduleDAGMILive(
+      C, std::make_unique<PostGenericScheduler>(C),
+      /*RemoveKillFlags=*/true);
   const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
   DAG->addMutation(createLoadClusterDAGMutation(DAG->TII, DAG->TRI));
   if (ST.shouldClusterStores())

llvm/lib/Target/AMDGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ add_llvm_target(AMDGPUCodeGen
   AMDGPUMacroFusion.cpp
   AMDGPUMCInstLower.cpp
   AMDGPUMemoryUtils.cpp
+  AMDGPUMLSchedStrategy.cpp
   AMDGPUIGroupLP.cpp
   AMDGPULowerVGPREncoding.cpp
   AMDGPUMCResourceInfo.cpp

llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp

Lines changed: 10 additions & 10 deletions
@@ -184,15 +184,15 @@ static bool canUsePressureDiffs(const SUnit &SU) {
   return true;
 }
 
-static void getRegisterPressures(
+void GCNSchedStrategy::getRegisterPressures(
     bool AtTop, const RegPressureTracker &RPTracker, SUnit *SU,
    std::vector<unsigned> &Pressure, std::vector<unsigned> &MaxPressure,
    GCNDownwardRPTracker &DownwardTracker, GCNUpwardRPTracker &UpwardTracker,
    ScheduleDAGMI *DAG, const SIRegisterInfo *SRI) {
   // getDownwardPressure() and getUpwardPressure() make temporary changes to
   // the tracker, so we need to pass those function a non-const copy.
   RegPressureTracker &TempTracker = const_cast<RegPressureTracker &>(RPTracker);
-  if (!GCNTrackers) {
+  if (!useGCNTrackers()) {
     AtTop
         ? TempTracker.getDownwardPressure(SU->getInstr(), Pressure, MaxPressure)
         : TempTracker.getUpwardPressure(SU->getInstr(), Pressure, MaxPressure);
@@ -244,7 +244,7 @@ void GCNSchedStrategy::initCandidate(SchedCandidate &Cand, SUnit *SU,
   //
   // In EXPENSIVE_CHECKS, we always query RPTracker to verify the results of
   // PressureDiffs.
-  if (AtTop || !canUsePressureDiffs(*SU) || GCNTrackers) {
+  if (AtTop || !canUsePressureDiffs(*SU) || useGCNTrackers()) {
     getRegisterPressures(AtTop, RPTracker, SU, Pressure, MaxPressure,
                          DownwardTracker, UpwardTracker, DAG, SRI);
   } else {
@@ -388,7 +388,7 @@ void GCNSchedStrategy::pickNodeFromQueue(SchedBoundary &Zone,
   unsigned VGPRPressure = 0;
   IsPending = false;
   if (DAG->isTrackingPressure()) {
-    if (!GCNTrackers) {
+    if (!useGCNTrackers()) {
       SGPRPressure = Pressure[AMDGPU::RegisterPressureSets::SReg_32];
       VGPRPressure = Pressure[AMDGPU::RegisterPressureSets::VGPR_32];
     } else {
@@ -611,7 +611,7 @@ SUnit *GCNSchedStrategy::pickNode(bool &IsTopNode) {
 }
 
 void GCNSchedStrategy::schedNode(SUnit *SU, bool IsTopNode) {
-  if (GCNTrackers) {
+  if (useGCNTrackers()) {
     MachineInstr *MI = SU->getInstr();
     IsTopNode ? (void)DownwardTracker.advance(MI, false)
               : UpwardTracker.recede(*MI);
@@ -693,7 +693,7 @@ GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
   SchedStages.push_back(GCNSchedStageID::UnclusteredHighRPReschedule);
   SchedStages.push_back(GCNSchedStageID::ClusteredLowOccupancyReschedule);
   SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
-  GCNTrackers = GCNTrackers & !IsLegacyScheduler;
+  UseGCNTrackers = GCNTrackers & !IsLegacyScheduler;
 }
 
 GCNMaxILPSchedStrategy::GCNMaxILPSchedStrategy(const MachineSchedContext *C)
@@ -1115,9 +1115,10 @@ void GCNScheduleDAGMILive::finalizeSchedule() {
 void GCNScheduleDAGMILive::runSchedStages() {
   LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");
 
+  GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
   if (!Regions.empty()) {
     BBLiveInMap = getRegionLiveInMap();
-    if (GCNTrackers)
+    if (S.useGCNTrackers())
       RegionLiveOuts.buildLiveRegMap();
   }
 
@@ -1129,7 +1130,6 @@ void GCNScheduleDAGMILive::runSchedStages() {
   }
 #endif
 
-  GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
   while (S.advanceStage()) {
     auto Stage = createSchedStage(S.getCurrentStage());
     if (!Stage->initGCNSchedStage())
@@ -1145,7 +1145,7 @@ void GCNScheduleDAGMILive::runSchedStages() {
       continue;
     }
 
-    if (GCNTrackers) {
+    if (S.useGCNTrackers()) {
       GCNDownwardRPTracker *DownwardTracker = S.getDownwardTracker();
       GCNUpwardRPTracker *UpwardTracker = S.getUpwardTracker();
       GCNRPTracker::LiveRegSet *RegionLiveIns =
@@ -1294,7 +1294,7 @@ bool PreRARematStage::initGCNSchedStage() {
 
   // Rematerialize identified instructions and update scheduler's state.
   rematerialize();
-  if (GCNTrackers)
+  if (S.useGCNTrackers())
     DAG.RegionLiveOuts.buildLiveRegMap();
   REMAT_DEBUG({
     dbgs() << "Retrying function scheduling with new min. occupancy of "
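
These hunks replace direct reads of the global GCNTrackers flag with a per-strategy UseGCNTrackers member and a useGCNTrackers() accessor, which is what lets AMDGPUMLSchedStrategy force the GCN pressure trackers on in its constructor. The matching GCNSchedStrategy.h hunk is not shown in this excerpt; a minimal sketch of the assumed declaration:

// Assumed shape of the flag/accessor pair in GCNSchedStrategy.h (that hunk is
// not part of this excerpt); base class and other members elided.
class GCNSchedStrategy /* : public GenericScheduler */ {
protected:
  // Per-strategy switch: AMDGPUMLSchedStrategy sets it to true, while
  // GCNMaxOccupancySchedStrategy derives it from the global GCNTrackers flag.
  bool UseGCNTrackers = false;

public:
  bool useGCNTrackers() const { return UseGCNTrackers; }
};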
