Skip to content

Commit 5556890

Browse files
committed
[LLVM][AMDGPU] extend IGLP
1 parent b581bd3 commit 5556890

File tree

2 files changed

+219
-30
lines changed

2 files changed

+219
-30
lines changed

llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,8 @@ class InstructionRule {
9494
std::optional<SmallVector<SUnit *, 4>> Cache;
9595

9696
public:
97-
virtual bool
98-
apply(const SUnit *, const ArrayRef<SUnit *>,
99-
SmallVectorImpl<SchedGroup> &) {
97+
virtual bool apply(const SUnit *, const ArrayRef<SUnit *>,
98+
SmallVectorImpl<SchedGroup> &) {
10099
return true;
101100
};
102101

@@ -696,6 +695,76 @@ bool PipelineSolver::solveExact() {
696695
return FinishedExploring;
697696
}
698697

698+
// Implement a IGLP scheduling strategy.
699+
class IGLPStrategy {
700+
protected:
701+
ScheduleDAGInstrs *DAG;
702+
703+
const SIInstrInfo *TII;
704+
705+
public:
706+
/// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
707+
virtual bool applyIGLPStrategy(
708+
DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
709+
DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
710+
AMDGPU::SchedulingPhase Phase) = 0;
711+
712+
// Returns true if this strategy should be applied to a ScheduleDAG.
713+
virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
714+
AMDGPU::SchedulingPhase Phase) = 0;
715+
716+
bool IsBottomUp = true;
717+
718+
IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
719+
: DAG(DAG), TII(TII) {}
720+
721+
virtual ~IGLPStrategy() = default;
722+
};
723+
724+
class MaxsOpt final : public IGLPStrategy {
725+
private:
726+
public:
727+
bool applyIGLPStrategy(
728+
DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
729+
DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
730+
AMDGPU::SchedulingPhase Phase) override;
731+
732+
bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
733+
AMDGPU::SchedulingPhase Phase) override {
734+
return true;
735+
}
736+
737+
MaxsOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
738+
: IGLPStrategy(DAG, TII) {
739+
IsBottomUp = true;
740+
}
741+
};
742+
743+
bool MaxsOpt::applyIGLPStrategy(
744+
DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
745+
DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
746+
AMDGPU::SchedulingPhase Phase) {
747+
// Count the number of MFMA instructions.
748+
unsigned MFMACount = 0;
749+
for (const MachineInstr &I : *DAG)
750+
if (TII->isMFMAorWMMA(I))
751+
++MFMACount;
752+
753+
const unsigned PipelineSyncID = 0;
754+
SchedGroup *SG = nullptr;
755+
for (unsigned I = 0; I < MFMACount * 3; ++I) {
756+
SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
757+
SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
758+
SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
759+
760+
SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
761+
SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
762+
SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
763+
}
764+
765+
return true;
766+
}
767+
699768
template <typename T>
700769
void PipelineSolver::greedyFind(
701770
std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
@@ -815,33 +884,8 @@ enum IGLPStrategyID : int {
815884
MFMASmallGemmOptID = 0,
816885
MFMASmallGemmSingleWaveOptID = 1,
817886
MFMAExpInterleaveID = 2,
818-
MFMAExpSimpleInterleaveID = 3
819-
};
820-
821-
// Implement a IGLP scheduling strategy.
822-
class IGLPStrategy {
823-
protected:
824-
ScheduleDAGInstrs *DAG;
825-
826-
const SIInstrInfo *TII;
827-
828-
public:
829-
/// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
830-
virtual bool applyIGLPStrategy(
831-
DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
832-
DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
833-
AMDGPU::SchedulingPhase Phase) = 0;
834-
835-
// Returns true if this strategy should be applied to a ScheduleDAG.
836-
virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
837-
AMDGPU::SchedulingPhase Phase) = 0;
838-
839-
bool IsBottomUp = true;
840-
841-
IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
842-
: DAG(DAG), TII(TII) {}
843-
844-
virtual ~IGLPStrategy() = default;
887+
MFMAExpSimpleInterleaveID = 3,
888+
MaxsID = 4
845889
};
846890

847891
class MFMASmallGemmOpt final : public IGLPStrategy {
@@ -2335,6 +2379,8 @@ createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
23352379
return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);
23362380
case MFMAExpSimpleInterleaveID:
23372381
return std::make_unique<MFMAExpSimpleInterleaveOpt>(DAG, TII);
2382+
case MaxsID:
2383+
return std::make_unique<MaxsOpt>(DAG, TII);
23382384
}
23392385

23402386
llvm_unreachable("Unknown IGLPStrategyID");
@@ -2599,10 +2645,14 @@ void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
25992645
}
26002646

26012647
if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2648+
// llvm::dbgs() << "before pipeline solver\n";
2649+
// DAG->dump();
26022650
PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
26032651
// PipelineSolver performs the mutation by adding the edges it
26042652
// determined as the best
26052653
PS.solve();
2654+
// llvm::dbgs() << "after pipeline solver\n";
2655+
// DAG->dump();
26062656
return;
26072657
}
26082658
}
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=amdgcn -mcpu=gfx90a -verify-machineinstrs < %s | FileCheck -check-prefix=GCN %s
3+
4+
define amdgpu_kernel void @test_iglp_opt_mfma_gemm(ptr addrspace(3) noalias %in, ptr addrspace(3) noalias %out) #0 {
5+
; GCN-LABEL: test_iglp_opt_mfma_gemm:
6+
; GCN: ; %bb.0: ; %entry
7+
; GCN-NEXT: s_load_dwordx2 s[0:1], s[4:5], 0x24
8+
; GCN-NEXT: v_lshlrev_b32_e32 v0, 7, v0
9+
; GCN-NEXT: v_and_b32_e32 v0, 0x1ff80, v0
10+
; GCN-NEXT: v_mov_b32_e32 v3, 2.0
11+
; GCN-NEXT: ; iglp_opt mask(0x00000000)
12+
; GCN-NEXT: s_waitcnt lgkmcnt(0)
13+
; GCN-NEXT: v_add_u32_e32 v1, s0, v0
14+
; GCN-NEXT: v_add_u32_e32 v2, 0x6000, v1
15+
; GCN-NEXT: ds_read_b128 a[28:31], v2 offset:57456
16+
; GCN-NEXT: ds_read_b128 a[24:27], v2 offset:57440
17+
; GCN-NEXT: ds_read_b128 a[20:23], v2 offset:57424
18+
; GCN-NEXT: ds_read_b128 a[16:19], v2 offset:57408
19+
; GCN-NEXT: ds_read_b128 a[0:3], v2 offset:57344
20+
; GCN-NEXT: ds_read_b128 a[4:7], v2 offset:57360
21+
; GCN-NEXT: ds_read_b128 a[8:11], v2 offset:57376
22+
; GCN-NEXT: ds_read_b128 a[12:15], v2 offset:57392
23+
; GCN-NEXT: v_mov_b32_e32 v2, 1.0
24+
; GCN-NEXT: ds_read_b128 a[60:63], v1 offset:49264
25+
; GCN-NEXT: ds_read_b128 a[56:59], v1 offset:49248
26+
; GCN-NEXT: ds_read_b128 a[52:55], v1 offset:49232
27+
; GCN-NEXT: ds_read_b128 a[48:51], v1 offset:49216
28+
; GCN-NEXT: ds_read_b128 a[44:47], v1 offset:49200
29+
; GCN-NEXT: ds_read_b128 a[40:43], v1 offset:49184
30+
; GCN-NEXT: ds_read_b128 a[36:39], v1 offset:49168
31+
; GCN-NEXT: ds_read_b128 a[32:35], v1 offset:49152
32+
; GCN-NEXT: s_waitcnt lgkmcnt(8)
33+
; GCN-NEXT: v_mfma_f32_32x32x1f32 a[0:31], v2, v3, a[0:31]
34+
; GCN-NEXT: ds_read_b128 a[156:159], v1 offset:112
35+
; GCN-NEXT: ds_read_b128 a[152:155], v1 offset:96
36+
; GCN-NEXT: ds_read_b128 a[68:71], v1 offset:24592
37+
; GCN-NEXT: ds_read_b128 a[64:67], v1 offset:24576
38+
; GCN-NEXT: v_add_u32_e32 v0, s1, v0
39+
; GCN-NEXT: s_waitcnt lgkmcnt(4)
40+
; GCN-NEXT: v_mfma_f32_32x32x1f32 a[32:63], v2, v3, a[32:63]
41+
; GCN-NEXT: ds_read_b128 a[148:151], v1 offset:80
42+
; GCN-NEXT: ds_read_b128 a[144:147], v1 offset:64
43+
; GCN-NEXT: ds_read_b128 a[128:131], v1
44+
; GCN-NEXT: ds_read_b128 a[132:135], v1 offset:16
45+
; GCN-NEXT: ds_read_b128 a[136:139], v1 offset:32
46+
; GCN-NEXT: ds_read_b128 a[140:143], v1 offset:48
47+
; GCN-NEXT: s_waitcnt lgkmcnt(0)
48+
; GCN-NEXT: v_mfma_f32_32x32x1f32 a[128:159], v2, v3, a[128:159]
49+
; GCN-NEXT: ds_read_b128 a[124:127], v1 offset:8304
50+
; GCN-NEXT: ds_read_b128 a[120:123], v1 offset:8288
51+
; GCN-NEXT: ds_read_b128 a[116:119], v1 offset:8272
52+
; GCN-NEXT: ds_read_b128 a[112:115], v1 offset:8256
53+
; GCN-NEXT: ds_read_b128 a[108:111], v1 offset:8240
54+
; GCN-NEXT: ds_read_b128 a[104:107], v1 offset:8224
55+
; GCN-NEXT: ds_read_b128 a[100:103], v1 offset:8208
56+
; GCN-NEXT: ds_read_b128 a[96:99], v1 offset:8192
57+
; GCN-NEXT: s_waitcnt lgkmcnt(0)
58+
; GCN-NEXT: v_mfma_f32_32x32x1f32 a[96:127], v2, v3, a[96:127]
59+
; GCN-NEXT: ds_read_b128 a[92:95], v1 offset:24688
60+
; GCN-NEXT: ds_read_b128 a[88:91], v1 offset:24672
61+
; GCN-NEXT: ds_read_b128 a[84:87], v1 offset:24656
62+
; GCN-NEXT: ds_read_b128 a[80:83], v1 offset:24640
63+
; GCN-NEXT: ds_read_b128 a[76:79], v1 offset:24624
64+
; GCN-NEXT: ds_read_b128 a[72:75], v1 offset:24608
65+
; GCN-NEXT: s_nop 2
66+
; GCN-NEXT: ds_write_b128 v0, a[156:159] offset:112
67+
; GCN-NEXT: ds_write_b128 v0, a[152:155] offset:96
68+
; GCN-NEXT: ds_write_b128 v0, a[148:151] offset:80
69+
; GCN-NEXT: ds_write_b128 v0, a[144:147] offset:64
70+
; GCN-NEXT: ds_write_b128 v0, a[140:143] offset:48
71+
; GCN-NEXT: ds_write_b128 v0, a[136:139] offset:32
72+
; GCN-NEXT: ds_write_b128 v0, a[132:135] offset:16
73+
; GCN-NEXT: ds_write_b128 v0, a[128:131]
74+
; GCN-NEXT: v_mov_b32_e32 v0, s1
75+
; GCN-NEXT: s_waitcnt lgkmcnt(8)
76+
; GCN-NEXT: v_mfma_f32_32x32x1f32 a[64:95], v2, v3, a[64:95]
77+
; GCN-NEXT: ds_write_b128 v0, a[56:59] offset:24672
78+
; GCN-NEXT: ds_write_b128 v0, a[60:63] offset:24688
79+
; GCN-NEXT: ds_write_b128 v0, a[48:51] offset:24640
80+
; GCN-NEXT: ds_write_b128 v0, a[120:123] offset:8288
81+
; GCN-NEXT: ds_write_b128 v0, a[124:127] offset:8304
82+
; GCN-NEXT: ds_write_b128 v0, a[112:115] offset:8256
83+
; GCN-NEXT: ds_write_b128 v0, a[116:119] offset:8272
84+
; GCN-NEXT: ds_write_b128 v0, a[104:107] offset:8224
85+
; GCN-NEXT: ds_write_b128 v0, a[108:111] offset:8240
86+
; GCN-NEXT: ds_write_b128 v0, a[96:99] offset:8192
87+
; GCN-NEXT: ds_write_b128 v0, a[100:103] offset:8208
88+
; GCN-NEXT: ds_write_b128 v0, a[52:55] offset:24656
89+
; GCN-NEXT: ds_write_b128 v0, a[40:43] offset:24608
90+
; GCN-NEXT: ds_write_b128 v0, a[44:47] offset:24624
91+
; GCN-NEXT: ds_write_b128 v0, a[32:35] offset:24576
92+
; GCN-NEXT: ds_write_b128 v0, a[36:39] offset:24592
93+
; GCN-NEXT: ds_write_b128 v0, a[24:27] offset:32864
94+
; GCN-NEXT: ds_write_b128 v0, a[28:31] offset:32880
95+
; GCN-NEXT: ds_write_b128 v0, a[16:19] offset:32832
96+
; GCN-NEXT: ds_write_b128 v0, a[88:91] offset:16480
97+
; GCN-NEXT: ds_write_b128 v0, a[92:95] offset:16496
98+
; GCN-NEXT: ds_write_b128 v0, a[80:83] offset:16448
99+
; GCN-NEXT: ds_write_b128 v0, a[84:87] offset:16464
100+
; GCN-NEXT: ds_write_b128 v0, a[72:75] offset:16416
101+
; GCN-NEXT: ds_write_b128 v0, a[76:79] offset:16432
102+
; GCN-NEXT: ds_write_b128 v0, a[64:67] offset:16384
103+
; GCN-NEXT: ds_write_b128 v0, a[68:71] offset:16400
104+
; GCN-NEXT: ds_write_b128 v0, a[20:23] offset:32848
105+
; GCN-NEXT: ds_write_b128 v0, a[8:11] offset:32800
106+
; GCN-NEXT: ds_write_b128 v0, a[12:15] offset:32816
107+
; GCN-NEXT: ds_write_b128 v0, a[0:3] offset:32768
108+
; GCN-NEXT: ds_write_b128 v0, a[4:7] offset:32784
109+
; GCN-NEXT: s_endpgm
110+
entry:
111+
call void @llvm.amdgcn.iglp.opt(i32 4)
112+
%idx = call i32 @llvm.amdgcn.workitem.id.x()
113+
%load.0.addr = getelementptr <32 x float>, ptr addrspace(3) %in, i32 %idx
114+
%load.0 = load <32 x float>, ptr addrspace(3) %load.0.addr
115+
%load.1.addr = getelementptr <32 x float>, ptr addrspace(3) %load.0.addr, i32 64
116+
%load.1 = load <32 x float>, ptr addrspace(3) %load.1.addr
117+
%load.2.addr = getelementptr <32 x float>, ptr addrspace(3) %load.1.addr, i32 128
118+
%load.2 = load <32 x float>, ptr addrspace(3) %load.2.addr
119+
%load.3.addr = getelementptr <32 x float>, ptr addrspace(3) %load.2.addr, i32 192
120+
%load.3 = load <32 x float>, ptr addrspace(3) %load.3.addr
121+
%load.4.addr = getelementptr <32 x float>, ptr addrspace(3) %load.3.addr, i32 256
122+
%load.4 = load <32 x float>, ptr addrspace(3) %load.4.addr
123+
%mai.0 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.0, i32 0, i32 0, i32 0)
124+
%mai.1 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.1, i32 0, i32 0, i32 0)
125+
%mai.2 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.2, i32 0, i32 0, i32 0)
126+
%mai.3 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.3, i32 0, i32 0, i32 0)
127+
%mai.4 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.4, i32 0, i32 0, i32 0)
128+
%store.0.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 %idx
129+
store <32 x float> %mai.0, ptr addrspace(3) %store.0.addr
130+
%store.1.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 64
131+
store <32 x float> %mai.1, ptr addrspace(3) %store.1.addr
132+
%store.2.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 128
133+
store <32 x float> %mai.2, ptr addrspace(3) %store.2.addr
134+
%store.3.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 192
135+
store <32 x float> %mai.3, ptr addrspace(3) %store.3.addr
136+
%store.4.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 256
137+
store <32 x float> %mai.4, ptr addrspace(3) %store.4.addr
138+
ret void
139+
}

0 commit comments

Comments
 (0)