Skip to content

Commit 9a29353

Browse files
authored
AMDGPU: Handle multiple AGPR MFMA rewrites (#147975)
I have this firing on one of the real examples, need to produce the tests and check a few edge cases
1 parent 5544492 commit 9a29353

File tree

2 files changed

+58
-29
lines changed

2 files changed

+58
-29
lines changed

llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,47 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
5757
TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
5858
LIS(LIS) {}
5959

60+
// TODO: Remove this restriction
61+
bool mfmaHasSameSrc2AndDstReg(const MachineInstr &MI) const {
62+
const MachineOperand *Src2 = TII.getNamedOperand(MI, AMDGPU::OpName::src2);
63+
const MachineOperand *Dst = TII.getNamedOperand(MI, AMDGPU::OpName::vdst);
64+
return Src2->getReg() == Dst->getReg() &&
65+
Src2->getSubReg() == Dst->getSubReg();
66+
}
67+
68+
bool isRewriteCandidate(const MachineInstr &MI) const {
69+
return TII.isMAI(MI) &&
70+
AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1 &&
71+
mfmaHasSameSrc2AndDstReg(MI);
72+
}
73+
6074
/// Compute the register class constraints based on the uses of \p Reg,
61-
/// excluding uses from \p ExceptMI. This should be nearly identical to
75+
/// excluding MFMA uses from which can be rewritten to change the register
76+
/// class constraint. This should be nearly identical to
6277
/// MachineRegisterInfo::recomputeRegClass.
6378
const TargetRegisterClass *
64-
recomputeRegClassExcept(Register Reg, const TargetRegisterClass *OldRC,
65-
const TargetRegisterClass *NewRC,
66-
const MachineInstr *ExceptMI) const;
79+
recomputeRegClassExceptRewritable(Register Reg,
80+
const TargetRegisterClass *OldRC,
81+
const TargetRegisterClass *NewRC) const;
6782

6883
bool run(MachineFunction &MF) const;
6984
};
7085

7186
const TargetRegisterClass *
72-
AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExcept(
87+
AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExceptRewritable(
7388
Register Reg, const TargetRegisterClass *OldRC,
74-
const TargetRegisterClass *NewRC, const MachineInstr *ExceptMI) const {
89+
const TargetRegisterClass *NewRC) const {
7590

7691
// Accumulate constraints from all uses.
7792
for (MachineOperand &MO : MRI.reg_nodbg_operands(Reg)) {
7893
// Apply the effect of the given operand to NewRC.
7994
MachineInstr *MI = MO.getParent();
80-
if (MI == ExceptMI)
95+
96+
// We can swap the classes of dst + src2 as a pair to AGPR, so ignore the
97+
// effects of rewrite candidates. It just so happens that we can use either
98+
// AGPR or VGPR in src0/src1, so don't bother checking the constraint
99+
// effects of the individual operands.
100+
if (isRewriteCandidate(*MI))
81101
continue;
82102

83103
unsigned OpNo = &MO - &MI->getOperand(0);
@@ -190,10 +210,13 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
190210
// first place, as well as need to assign another register, and need to
191211
// figure out where to put them. The live range splitting is smarter than
192212
// anything we're doing here, so trust it did something reasonable.
193-
const TargetRegisterClass *Src2ExceptRC = recomputeRegClassExcept(
194-
Src2->getReg(), Src2VirtRegRC, VirtRegRC, CopySrcMI);
195-
if (!Src2ExceptRC)
213+
const TargetRegisterClass *Src2ExceptRC =
214+
recomputeRegClassExceptRewritable(Src2->getReg(), Src2VirtRegRC,
215+
VirtRegRC);
216+
if (!Src2ExceptRC) {
217+
LLVM_DEBUG(dbgs() << "Could not recompute the regclass\n");
196218
continue;
219+
}
197220

198221
const TargetRegisterClass *NewSrc2ConstraintRC =
199222
TII.getRegClass(TII.get(AGPROp), Src2->getOperandNo(), &TRI, MF);
@@ -203,8 +226,6 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
203226
const TargetRegisterClass *NewSrc2RC =
204227
TRI.getCommonSubClass(Src2ExceptRC, NewSrc2ConstraintRC);
205228
if (!NewSrc2RC) {
206-
// TODO: This is ignoring ther rewritable uses. e.g. a rewritable MFMA
207-
// using a rewritable MFMA can be rewritten as a pair.
208229
LLVM_DEBUG(dbgs() << "Other uses of " << printReg(Src2->getReg(), &TRI)
209230
<< " are incompatible with replacement class\n");
210231
continue;
@@ -215,8 +236,19 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
215236

216237
CopySrcMI->setDesc(TII.get(AGPROp));
217238

218-
// TODO: Is replacing too aggressive, fixup these instructions only?
219-
MRI.replaceRegWith(CopySrcReg, VReg);
239+
// Perform replacement of the register, rewriting the rewritable uses.
240+
for (MachineInstr &UseMI :
241+
make_early_inc_range(MRI.reg_instructions(CopySrcReg))) {
242+
if (TII.isMAI(UseMI)) {
243+
// Note the register we need to rewrite may still appear in src0/src1,
244+
// but that's fine since those can use A or V anyway.
245+
int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(UseMI.getOpcode());
246+
if (ReplacementOp != -1)
247+
UseMI.setDesc(TII.get(ReplacementOp));
248+
}
249+
250+
UseMI.substituteRegister(CopySrcReg, VReg, AMDGPU::NoSubRegister, TRI);
251+
}
220252

221253
LLVM_DEBUG(dbgs() << "Replaced VGPR MFMA with AGPR: " << *CopySrcMI);
222254

llvm/test/CodeGen/AMDGPU/inflate-reg-class-vgpr-mfma-to-av-with-load-source.mir

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -296,16 +296,15 @@ body: |
296296
; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000)
297297
; CHECK-NEXT: liveins: $vcc, $vgpr0_vgpr1
298298
; CHECK-NEXT: {{ $}}
299-
; CHECK-NEXT: renamable $vgpr2_vgpr3 = GLOBAL_LOAD_DWORDX2 undef renamable $vgpr0_vgpr1, 0, 0, implicit $exec :: (load (s64), addrspace 1)
300-
; CHECK-NEXT: renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17 = V_MFMA_F32_32X32X8F16_mac_vgprcd_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 0, 0, 0, implicit $mode, implicit $exec
301-
; CHECK-NEXT: renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17 = V_MFMA_F32_32X32X8F16_mac_vgprcd_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, killed $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 0, 0, 0, implicit $mode, implicit $exec
299+
; CHECK-NEXT: renamable $agpr0_agpr1 = GLOBAL_LOAD_DWORDX2 undef renamable $vgpr0_vgpr1, 0, 0, implicit $exec :: (load (s64), addrspace 1)
300+
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = V_MFMA_F32_32X32X8F16_mac_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15, 0, 0, 0, implicit $mode, implicit $exec
301+
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = V_MFMA_F32_32X32X8F16_mac_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, killed $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15, 0, 0, 0, implicit $mode, implicit $exec
302302
; CHECK-NEXT: S_CBRANCH_VCCNZ %bb.1, implicit $vcc
303303
; CHECK-NEXT: S_BRANCH %bb.2
304304
; CHECK-NEXT: {{ $}}
305305
; CHECK-NEXT: bb.2:
306-
; CHECK-NEXT: liveins: $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17:0x00000000FFFFFFFF
306+
; CHECK-NEXT: liveins: $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15:0x00000000FFFFFFFF
307307
; CHECK-NEXT: {{ $}}
308-
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = COPY killed renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17
309308
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7
310309
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15
311310
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23
@@ -384,16 +383,15 @@ body: |
384383
; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000)
385384
; CHECK-NEXT: liveins: $vcc, $vgpr0_vgpr1
386385
; CHECK-NEXT: {{ $}}
387-
; CHECK-NEXT: renamable $vgpr2_vgpr3 = GLOBAL_LOAD_DWORDX2 undef renamable $vgpr0_vgpr1, 0, 0, implicit $exec :: (load (s64), addrspace 1)
388-
; CHECK-NEXT: renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17 = V_MFMA_F32_32X32X8F16_mac_vgprcd_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 0, 0, 0, implicit $mode, implicit $exec
389-
; CHECK-NEXT: renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17 = V_MFMA_F32_32X32X8F16_mac_vgprcd_e64 killed $vgpr4_vgpr5, $vgpr2_vgpr3, $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 0, 0, 0, implicit $mode, implicit $exec
386+
; CHECK-NEXT: renamable $agpr0_agpr1 = GLOBAL_LOAD_DWORDX2 undef renamable $vgpr0_vgpr1, 0, 0, implicit $exec :: (load (s64), addrspace 1)
387+
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = V_MFMA_F32_32X32X8F16_mac_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15, 0, 0, 0, implicit $mode, implicit $exec
388+
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = V_MFMA_F32_32X32X8F16_mac_e64 killed $agpr2_agpr3, $agpr0_agpr1, $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15, 0, 0, 0, implicit $mode, implicit $exec
390389
; CHECK-NEXT: S_CBRANCH_VCCNZ %bb.1, implicit $vcc
391390
; CHECK-NEXT: S_BRANCH %bb.2
392391
; CHECK-NEXT: {{ $}}
393392
; CHECK-NEXT: bb.2:
394-
; CHECK-NEXT: liveins: $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17:0x00000000FFFFFFFF
393+
; CHECK-NEXT: liveins: $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15:0x00000000FFFFFFFF
395394
; CHECK-NEXT: {{ $}}
396-
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = COPY killed renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17
397395
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7
398396
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15
399397
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23
@@ -471,16 +469,15 @@ body: |
471469
; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000)
472470
; CHECK-NEXT: liveins: $vcc, $vgpr0_vgpr1
473471
; CHECK-NEXT: {{ $}}
474-
; CHECK-NEXT: renamable $vgpr2_vgpr3 = GLOBAL_LOAD_DWORDX2 undef renamable $vgpr0_vgpr1, 0, 0, implicit $exec :: (load (s64), addrspace 1)
475-
; CHECK-NEXT: renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17 = V_MFMA_F32_32X32X8F16_mac_vgprcd_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 0, 0, 0, implicit $mode, implicit $exec
476-
; CHECK-NEXT: renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17 = V_MFMA_F32_32X32X8F16_mac_vgprcd_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, killed $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17, 0, 0, 0, implicit $mode, implicit $exec
472+
; CHECK-NEXT: renamable $agpr0_agpr1 = GLOBAL_LOAD_DWORDX2 undef renamable $vgpr0_vgpr1, 0, 0, implicit $exec :: (load (s64), addrspace 1)
473+
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = V_MFMA_F32_32X32X8F16_mac_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15, 0, 0, 0, implicit $mode, implicit $exec
474+
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = V_MFMA_F32_32X32X8F16_mac_e64 $vgpr0_vgpr1, $vgpr0_vgpr1, killed $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15, 0, 0, 0, implicit $mode, implicit $exec
477475
; CHECK-NEXT: S_CBRANCH_VCCNZ %bb.1, implicit $vcc
478476
; CHECK-NEXT: S_BRANCH %bb.2
479477
; CHECK-NEXT: {{ $}}
480478
; CHECK-NEXT: bb.2:
481-
; CHECK-NEXT: liveins: $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17:0x00000000FFFFFFFF
479+
; CHECK-NEXT: liveins: $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15:0x00000000FFFFFFFF
482480
; CHECK-NEXT: {{ $}}
483-
; CHECK-NEXT: renamable $agpr0_agpr1_agpr2_agpr3_agpr4_agpr5_agpr6_agpr7_agpr8_agpr9_agpr10_agpr11_agpr12_agpr13_agpr14_agpr15 = COPY killed renamable $vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17
484481
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7
485482
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15
486483
; CHECK-NEXT: S_NOP 0, implicit-def $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23

0 commit comments

Comments
 (0)