
Commit 4d99b6d

[AMDGPU] Add scaffolding for ML focused scheduling strategy
This patch introduces scaffolding for a new machine instruction scheduling strategy optimized for machine learning workloads. The ML scheduler is enabled automatically for functions that carry the "amdgpu-workload-type"="ml" attribute.
1 parent ef37858 commit 4d99b6d
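
For illustration, this is how a pass or frontend could opt a function into the new scheduler; the attribute string is exactly what the patch's isMLWorkload() helper checks, while markAsMLWorkload is a hypothetical name that is not part of this commit. A minimal sketch:

#include "llvm/IR/Function.h"

// Hypothetical helper (not part of this commit): tag a function so that
// GCNTargetMachine::createMachineScheduler and createPostMachineScheduler
// select the ML scheduling strategies for it.
static void markAsMLWorkload(llvm::Function &F) {
  F.addFnAttr("amdgpu-workload-type", "ml");
}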

8 files changed: +378 −10 lines changed
llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp

Lines changed: 153 additions & 0 deletions

@@ -0,0 +1,153 @@
//===-- AMDGPUMLSchedStrategy.cpp - ML-focused Scheduler Strategy ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// ML-focused scheduling strategy for AMDGPU.
//
//===----------------------------------------------------------------------===//

#include "AMDGPUMLSchedStrategy.h"

using namespace llvm;

AMDGPUMLSchedStrategy::AMDGPUMLSchedStrategy(const MachineSchedContext *C)
    : GCNSchedStrategy(C) {
  SchedStages.push_back(GCNSchedStageID::ILPInitialSchedule);
  SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
  // Use the more accurate GCN pressure trackers.
  UseGCNTrackers = true;
}

void AMDGPUMLSchedStrategy::initialize(ScheduleDAGMI *DAG) {
  // The ML scheduling strategy is top-down only, to support new
  // resource-balancing heuristics.
  RegionPolicy.OnlyTopDown = true;
  RegionPolicy.OnlyBottomUp = false;

  GCNSchedStrategy::initialize(DAG);
}

bool AMDGPUMLSchedStrategy::tryCandidate(SchedCandidate &Cand,
                                         SchedCandidate &TryCand,
                                         SchedBoundary *Zone) const {
  // Initialize the candidate if needed.
  if (!Cand.isValid()) {
    TryCand.Reason = FirstValid;
    return true;
  }

  // Bias physical register defs and copies toward their uses and defs,
  // respectively.
  if (tryGreater(biasPhysReg(TryCand.SU, TryCand.AtTop),
                 biasPhysReg(Cand.SU, Cand.AtTop), TryCand, Cand, PhysReg))
    return TryCand.Reason != NoCand;

  // Avoid exceeding the target's limit.
  if (DAG->isTrackingPressure() &&
      tryPressure(TryCand.RPDelta.Excess, Cand.RPDelta.Excess, TryCand, Cand,
                  RegExcess, TRI, DAG->MF))
    return TryCand.Reason != NoCand;

  // We only compare a subset of features when comparing nodes between the
  // Top and Bottom boundaries. Some properties are simply incomparable; in
  // many other instances we should only override the other boundary if
  // something is a clearly good pick on one boundary. Skip heuristics that
  // are more "tie-breaking" in nature.
  bool SameBoundary = Zone != nullptr;
  if (SameBoundary) {
    // For loops that are acyclic path limited, aggressively schedule for
    // latency. Within a single cycle, whenever CurrMOps > 0, allow normal
    // heuristics to take precedence.
    if (Rem.IsAcyclicLatencyLimited && !Zone->getCurrMOps() &&
        tryLatency(TryCand, Cand, *Zone))
      return TryCand.Reason != NoCand;

    // Prioritize instructions that read unbuffered resources by stall cycles.
    if (tryLess(Zone->getLatencyStallCycles(TryCand.SU),
                Zone->getLatencyStallCycles(Cand.SU), TryCand, Cand, Stall))
      return TryCand.Reason != NoCand;
  }

  // Keep clustered nodes together to encourage downstream peephole
  // optimizations which may reduce resource requirements.
  //
  // This is a best effort to set things up for a post-RA pass. Optimizations
  // like generating loads of multiple registers should ideally be done within
  // the scheduler pass by combining the loads during DAG postprocessing.
  unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID;
  unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID;
  bool CandIsClusterSucc =
      isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx);
  bool TryCandIsClusterSucc =
      isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx);

  if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand,
                 Cluster))
    return TryCand.Reason != NoCand;

  if (SameBoundary) {
    // Weak edges are for clustering and other constraints.
    if (tryLess(getWeakLeft(TryCand.SU, TryCand.AtTop),
                getWeakLeft(Cand.SU, Cand.AtTop), TryCand, Cand, Weak))
      return TryCand.Reason != NoCand;
  }

  // Avoid increasing the max pressure of the entire region.
  if (DAG->isTrackingPressure() &&
      tryPressure(TryCand.RPDelta.CurrentMax, Cand.RPDelta.CurrentMax, TryCand,
                  Cand, RegMax, TRI, DAG->MF))
    return TryCand.Reason != NoCand;

  if (SameBoundary) {
    // Avoid critical resource consumption and balance the schedule.
    TryCand.initResourceDelta(DAG, SchedModel);
    if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
                TryCand, Cand, ResourceReduce))
      return TryCand.Reason != NoCand;
    if (tryGreater(TryCand.ResDelta.DemandedResources,
                   Cand.ResDelta.DemandedResources, TryCand, Cand,
                   ResourceDemand))
      return TryCand.Reason != NoCand;

    // Avoid serializing long latency dependence chains.
    // For acyclic path limited loops, latency was already checked above.
    if (!RegionPolicy.DisableLatencyHeuristic && TryCand.Policy.ReduceLatency &&
        !Rem.IsAcyclicLatencyLimited && tryLatency(TryCand, Cand, *Zone))
      return TryCand.Reason != NoCand;

    // Fall through to original instruction order.
    if ((Zone->isTop() && TryCand.SU->NodeNum < Cand.SU->NodeNum) ||
        (!Zone->isTop() && TryCand.SU->NodeNum > Cand.SU->NodeNum)) {
      TryCand.Reason = NodeOrder;
      return true;
    }
  }

  return false;
}

AMDGPUMLPostSchedStrategy::AMDGPUMLPostSchedStrategy(
    const MachineSchedContext *C)
    : PostGenericScheduler(C) {}

bool AMDGPUMLPostSchedStrategy::tryCandidate(SchedCandidate &Cand,
                                             SchedCandidate &TryCand) {
  // Initialize the candidate if needed.
  if (!Cand.isValid()) {
    TryCand.Reason = FirstValid;
    return true;
  }

  // Fall through to original instruction order.
  // This effectively only enables hazard checking for post-RA scheduling.
  if (TryCand.SU->NodeNum < Cand.SU->NodeNum) {
    TryCand.Reason = NodeOrder;
    return true;
  }

  return false;
}
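
Each heuristic above follows the same contract: helpers such as tryLess and tryGreater return true when the comparison is decisive either way, and they record a reason on TryCand only when TryCand wins, so the caller's "return TryCand.Reason != NoCand;" resolves to "pick TryCand" or "keep Cand" accordingly. Below is a self-contained sketch of that contract using simplified stand-in types, not LLVM's actual definitions:

#include <cassert>

// Simplified model of the scheduler's candidate-comparison contract.
// In LLVM's CandReason enum, smaller values (earlier enumerators) are
// stronger reasons; NoCand means "no decision recorded on TryCand".
enum CandReason { NoCand, Stall, Cluster, NodeOrder };

struct SchedCandidate {
  CandReason Reason = NoCand;
};

// Mirrors the shape of llvm::tryGreater: returns true when the comparison
// is decisive either way; TryCand.Reason is set only if TryCand wins.
static bool tryGreater(int TryVal, int CandVal, SchedCandidate &TryCand,
                       SchedCandidate &Cand, CandReason Reason) {
  if (TryVal > CandVal) {
    TryCand.Reason = Reason; // TryCand wins on this criterion.
    return true;
  }
  if (TryVal < CandVal) {
    if (Cand.Reason > Reason)
      Cand.Reason = Reason; // Cand wins; keep the stronger (smaller) reason.
    return true;
  }
  return false; // Tie: fall through to the next heuristic.
}

int main() {
  SchedCandidate Cand, TryCand;
  // Decisive in Cand's favor: the helper returns true, but TryCand.Reason
  // stays NoCand, so "TryCand.Reason != NoCand" correctly keeps Cand.
  assert(tryGreater(/*TryVal=*/0, /*CandVal=*/1, TryCand, Cand, Cluster));
  assert(TryCand.Reason == NoCand);
  // Decisive in TryCand's favor: the reason is recorded on TryCand.
  assert(tryGreater(/*TryVal=*/1, /*CandVal=*/0, TryCand, Cand, Cluster));
  assert(TryCand.Reason == Cluster);
  return 0;
}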
llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
//===-- AMDGPUMLSchedStrategy.h - ML-focused Scheduler Strategy -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// ML-focused scheduling strategy for AMDGPU.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H
#define LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H

#include "GCNSchedStrategy.h"
#include "llvm/CodeGen/MachineScheduler.h"

namespace llvm {

class AMDGPUMLSchedStrategy final : public GCNSchedStrategy {
protected:
  bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
                    SchedBoundary *Zone) const override;

public:
  AMDGPUMLSchedStrategy(const MachineSchedContext *C);

  void initialize(ScheduleDAGMI *DAG) override;
};

class AMDGPUMLPostSchedStrategy : public PostGenericScheduler {
protected:
  bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand) override;

public:
  AMDGPUMLPostSchedStrategy(const MachineSchedContext *C);
};

} // End namespace llvm

#endif // LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H

llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp

Lines changed: 22 additions & 0 deletions

@@ -24,6 +24,7 @@
 #include "AMDGPUIGroupLP.h"
 #include "AMDGPUISelDAGToDAG.h"
 #include "AMDGPULowerVGPREncoding.h"
+#include "AMDGPUMLSchedStrategy.h"
 #include "AMDGPUMacroFusion.h"
 #include "AMDGPUPerfHintAnalysis.h"
 #include "AMDGPUPreloadKernArgProlog.h"
@@ -636,6 +637,11 @@ static ScheduleDAGInstrs *createSIMachineScheduler(MachineSchedContext *C) {
   return new SIScheduleDAGMI(C);
 }

+static bool isMLWorkload(const Function &F) {
+  Attribute WorkloadAttr = F.getFnAttribute("amdgpu-workload-type");
+  return WorkloadAttr.isValid() && WorkloadAttr.getValueAsString() == "ml";
+}
+
 static ScheduleDAGInstrs *
 createGCNMaxOccupancyMachineScheduler(MachineSchedContext *C) {
   const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -659,6 +665,11 @@
   return DAG;
 }

+static ScheduleDAGInstrs *createGCNMLMachineScheduler(MachineSchedContext *C) {
+  return new GCNScheduleDAGMILive(C,
+                                  std::make_unique<AMDGPUMLSchedStrategy>(C));
+}
+
 static ScheduleDAGInstrs *
 createGCNMaxMemoryClauseMachineScheduler(MachineSchedContext *C) {
   const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -1170,6 +1181,9 @@ GCNTargetMachine::createMachineScheduler(MachineSchedContext *C) const {
   if (ST.enableSIScheduler())
     return createSIMachineScheduler(C);

+  if (isMLWorkload(C->MF->getFunction()))
+    return createGCNMLMachineScheduler(C);
+
   Attribute SchedStrategyAttr =
       C->MF->getFunction().getFnAttribute("amdgpu-sched-strategy");
   StringRef SchedStrategy = SchedStrategyAttr.isValid()
@@ -1191,11 +1205,19 @@
   if (SchedStrategy == "iterative-maxocc")
     return createIterativeGCNMaxOccupancyMachineScheduler(C);

+  if (SchedStrategy == "ml")
+    return createGCNMLMachineScheduler(C);
+
   return createGCNMaxOccupancyMachineScheduler(C);
 }

 ScheduleDAGInstrs *
 GCNTargetMachine::createPostMachineScheduler(MachineSchedContext *C) const {
+  if (isMLWorkload(C->MF->getFunction()))
+    return new GCNPostScheduleDAGMILive(
+        C, std::make_unique<AMDGPUMLPostSchedStrategy>(C),
+        /*RemoveKillFlags=*/true);
+
   ScheduleDAGMI *DAG =
       new GCNPostScheduleDAGMILive(C, std::make_unique<PostGenericScheduler>(C),
                                    /*RemoveKillFlags=*/true);
llvm/lib/Target/AMDGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -89,6 +89,7 @@ add_llvm_target(AMDGPUCodeGen
   AMDGPUMacroFusion.cpp
   AMDGPUMCInstLower.cpp
   AMDGPUMemoryUtils.cpp
+  AMDGPUMLSchedStrategy.cpp
   AMDGPUIGroupLP.cpp
   AMDGPULowerVGPREncoding.cpp
   AMDGPUMCResourceInfo.cpp

llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp

Lines changed: 10 additions & 10 deletions

@@ -184,15 +184,15 @@ static bool canUsePressureDiffs(const SUnit &SU) {
   return true;
 }

-static void getRegisterPressures(
+void GCNSchedStrategy::getRegisterPressures(
     bool AtTop, const RegPressureTracker &RPTracker, SUnit *SU,
     std::vector<unsigned> &Pressure, std::vector<unsigned> &MaxPressure,
     GCNDownwardRPTracker &DownwardTracker, GCNUpwardRPTracker &UpwardTracker,
     ScheduleDAGMI *DAG, const SIRegisterInfo *SRI) {
   // getDownwardPressure() and getUpwardPressure() make temporary changes to
   // the tracker, so we need to pass those function a non-const copy.
   RegPressureTracker &TempTracker = const_cast<RegPressureTracker &>(RPTracker);
-  if (!GCNTrackers) {
+  if (!useGCNTrackers()) {
     AtTop
         ? TempTracker.getDownwardPressure(SU->getInstr(), Pressure, MaxPressure)
         : TempTracker.getUpwardPressure(SU->getInstr(), Pressure, MaxPressure);
@@ -244,7 +244,7 @@ void GCNSchedStrategy::initCandidate(SchedCandidate &Cand, SUnit *SU,
   //
   // In EXPENSIVE_CHECKS, we always query RPTracker to verify the results of
   // PressureDiffs.
-  if (AtTop || !canUsePressureDiffs(*SU) || GCNTrackers) {
+  if (AtTop || !canUsePressureDiffs(*SU) || useGCNTrackers()) {
     getRegisterPressures(AtTop, RPTracker, SU, Pressure, MaxPressure,
                          DownwardTracker, UpwardTracker, DAG, SRI);
   } else {
@@ -388,7 +388,7 @@ void GCNSchedStrategy::pickNodeFromQueue(SchedBoundary &Zone,
   unsigned VGPRPressure = 0;
   IsPending = false;
   if (DAG->isTrackingPressure()) {
-    if (!GCNTrackers) {
+    if (!useGCNTrackers()) {
       SGPRPressure = Pressure[AMDGPU::RegisterPressureSets::SReg_32];
       VGPRPressure = Pressure[AMDGPU::RegisterPressureSets::VGPR_32];
     } else {
@@ -611,7 +611,7 @@ SUnit *GCNSchedStrategy::pickNode(bool &IsTopNode) {
 }

 void GCNSchedStrategy::schedNode(SUnit *SU, bool IsTopNode) {
-  if (GCNTrackers) {
+  if (useGCNTrackers()) {
     MachineInstr *MI = SU->getInstr();
     IsTopNode ? (void)DownwardTracker.advance(MI, false)
               : UpwardTracker.recede(*MI);
@@ -693,7 +693,7 @@ GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
   SchedStages.push_back(GCNSchedStageID::UnclusteredHighRPReschedule);
   SchedStages.push_back(GCNSchedStageID::ClusteredLowOccupancyReschedule);
   SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
-  GCNTrackers = GCNTrackers & !IsLegacyScheduler;
+  UseGCNTrackers = GCNTrackers & !IsLegacyScheduler;
 }

 GCNMaxILPSchedStrategy::GCNMaxILPSchedStrategy(const MachineSchedContext *C)
@@ -1118,9 +1118,10 @@ void GCNScheduleDAGMILive::finalizeSchedule() {
 void GCNScheduleDAGMILive::runSchedStages() {
   LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");

+  GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
   if (!Regions.empty()) {
     BBLiveInMap = getRegionLiveInMap();
-    if (GCNTrackers)
+    if (S.useGCNTrackers())
       RegionLiveOuts.buildLiveRegMap();
   }

@@ -1132,7 +1133,6 @@
   }
 #endif

-  GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
   while (S.advanceStage()) {
     auto Stage = createSchedStage(S.getCurrentStage());
     if (!Stage->initGCNSchedStage())
@@ -1148,7 +1148,7 @@
       continue;
     }

-    if (GCNTrackers) {
+    if (S.useGCNTrackers()) {
      GCNDownwardRPTracker *DownwardTracker = S.getDownwardTracker();
      GCNUpwardRPTracker *UpwardTracker = S.getUpwardTracker();
      GCNRPTracker::LiveRegSet *RegionLiveIns =
@@ -1297,7 +1297,7 @@ bool PreRARematStage::initGCNSchedStage() {

   // Rematerialize identified instructions and update scheduler's state.
   rematerialize();
-  if (GCNTrackers)
+  if (S.useGCNTrackers())
     DAG.RegionLiveOuts.buildLiveRegMap();
   REMAT_DEBUG({
     dbgs() << "Retrying function scheduling with new min. occupancy of "
llvm/lib/Target/AMDGPU/GCNSchedStrategy.h

Lines changed: 11 additions & 0 deletions

@@ -70,6 +70,13 @@ class GCNSchedStrategy : public GenericScheduler {
   void printCandidateDecision(const SchedCandidate &Current,
                               const SchedCandidate &Preferred);

+  void getRegisterPressures(bool AtTop, const RegPressureTracker &RPTracker,
+                            SUnit *SU, std::vector<unsigned> &Pressure,
+                            std::vector<unsigned> &MaxPressure,
+                            GCNDownwardRPTracker &DownwardTracker,
+                            GCNUpwardRPTracker &UpwardTracker,
+                            ScheduleDAGMI *DAG, const SIRegisterInfo *SRI);
+
   std::vector<unsigned> Pressure;

   std::vector<unsigned> MaxPressure;
@@ -94,6 +101,8 @@
   // GCN RP Tracker for botttom-up scheduling
   mutable GCNUpwardRPTracker UpwardTracker;

+  bool UseGCNTrackers = false;
+
 public:
   // schedule() have seen register pressure over the critical limits and had to
   // track register pressure for actual scheduling heuristics.
@@ -141,6 +150,8 @@

   bool hasNextStage() const;

+  bool useGCNTrackers() const { return UseGCNTrackers; }
+
   GCNSchedStageID getNextStage() const;

   GCNDownwardRPTracker *getDownwardTracker() { return &DownwardTracker; }
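
The GCNSchedStrategy changes replace direct reads of the global GCNTrackers command-line flag with a per-strategy query: constructors decide UseGCNTrackers once, and both the strategy and GCNScheduleDAGMILive consult useGCNTrackers() afterwards. A minimal sketch of a strategy opting in unconditionally, as AMDGPUMLSchedStrategy's constructor does (hypothetical subclass, for illustration only):

#include "GCNSchedStrategy.h"

using namespace llvm;

// Hypothetical strategy (not part of this commit) that always uses the
// more accurate GCN register-pressure trackers.
class AlwaysTrackedSchedStrategy : public GCNSchedStrategy {
public:
  AlwaysTrackedSchedStrategy(const MachineSchedContext *C)
      : GCNSchedStrategy(C) {
    // Protected field added by this patch; the scheduler driver reads it
    // back through useGCNTrackers() in runSchedStages().
    UseGCNTrackers = true;
  }
};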
