Skip to content

Commit f2932c5

Browse files
committed
AMDGPU: Handle rewriting VGPR MFMA fed from AGPR copy
Previously we handled the inverse situation only.
1 parent c811f52 commit f2932c5

File tree

3 files changed

+249
-290
lines changed

3 files changed

+249
-290
lines changed

llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp

Lines changed: 191 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
/// MFMA opcode.
1515
///
1616
/// TODO:
17+
/// - Handle rewrites of phis. This must be more careful than normal about the
18+
/// reassignment. We do not want to introduce an AGPR-to-AGPR copy inside of a
19+
/// loop, so it depends on the exact assignment of the copy.
20+
///
1721
/// - Update LiveIntervals incrementally instead of recomputing from scratch
1822
///
1923
//===----------------------------------------------------------------------===//
@@ -60,6 +64,32 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
6064
return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
6165
}
6266

67+
/// Find AV_* registers assigned to AGPRs (or virtual registers which were
68+
/// already required to be AGPR).
69+
///
70+
/// \return the assigned physical register that \p VReg is assigned to if it
71+
/// is an AGPR, otherwise MCRegister().
72+
MCRegister getAssignedAGPR(Register VReg) const {
73+
MCRegister PhysReg = VRM.getPhys(VReg);
74+
if (!PhysReg)
75+
return MCRegister();
76+
77+
const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
78+
if (!TRI.hasAGPRs(VirtRegRC))
79+
return MCRegister();
80+
81+
if (!TRI.hasVGPRs(VirtRegRC))
82+
return PhysReg;
83+
84+
// If this is an AV register, we have to check if the actual assignment is
85+
// to an AGPR
86+
const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
87+
return TRI.isAGPRClass(AssignedRC) ? PhysReg : MCRegister();
88+
}
89+
90+
bool tryReassigningMFMAChain(MachineInstr &MFMA, unsigned HintOpIdx,
91+
MCPhysReg PhysRegHint) const;
92+
6393
/// Compute the register class constraints based on the uses of \p Reg,
6494
/// excluding MFMA uses from which can be rewritten to change the register
6595
/// class constraint. This should be nearly identical to
@@ -74,6 +104,8 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
74104
Register Reg, SmallVectorImpl<MachineInstr *> &RewriteCandidates,
75105
SmallSetVector<Register, 4> &RewriteRegs) const;
76106

107+
bool tryFoldCopiesToAGPR(Register VReg, MCRegister AssignedAGPR) const;
108+
bool tryFoldCopiesFromAGPR(Register VReg, MCRegister AssignedAGPR) const;
77109
bool run(MachineFunction &MF) const;
78110
};
79111

@@ -154,6 +186,88 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExceptRewritable(
154186
return true;
155187
}
156188

189+
bool AMDGPURewriteAGPRCopyMFMAImpl::tryReassigningMFMAChain(
190+
MachineInstr &MFMA, unsigned HintOpIdx, MCPhysReg PhysRegHint) const {
191+
// src2 and dst have the same physical class constraint; try to preserve
192+
// the original src2 subclass if one were to exist.
193+
SmallVector<MachineInstr *, 4> RewriteCandidates = {&MFMA};
194+
SmallSetVector<Register, 4> RewriteRegs;
195+
196+
Register MFMAHintReg = MFMA.getOperand(HintOpIdx).getReg();
197+
// Make sure we reassign the MFMA we found the copy from first. We want
198+
// to ensure dst ends up in the physreg we were originally copying to.
199+
RewriteRegs.insert(MFMAHintReg);
200+
201+
// We've found av = COPY (MFMA), and need to verify that we can trivially
202+
// rewrite src2 to use the new AGPR. If we can't trivially replace it,
203+
// we're going to induce as many copies as we would have emitted in the
204+
// first place, as well as need to assign another register, and need to
205+
// figure out where to put them. The live range splitting is smarter than
206+
// anything we're doing here, so trust it did something reasonable.
207+
//
208+
// Note recomputeRegClassExceptRewritable will consider the constraints of
209+
// this MFMA's src2 as well as the src2/dst of any transitive MFMA users.
210+
if (!recomputeRegClassExceptRewritable(MFMAHintReg, RewriteCandidates,
211+
RewriteRegs)) {
212+
LLVM_DEBUG(dbgs() << "Could not recompute the regclass of dst reg "
213+
<< printReg(MFMAHintReg, &TRI) << '\n');
214+
return false;
215+
}
216+
217+
// If src2 and dst are different registers, we need to also reassign the
218+
// input to an available AGPR if it is compatible with all other uses.
219+
//
220+
// If we can't reassign it, we'd need to introduce a different copy
221+
// which is likely worse than the copy we'd be saving.
222+
//
223+
// It's likely that the MFMA is used in sequence with other MFMAs; if we
224+
// cannot migrate the full use/def chain of MFMAs, we would need to
225+
// introduce intermediate copies somewhere. So we only make the
226+
// transform if all the interfering MFMAs can also be migrated. Collect
227+
// the set of rewritable MFMAs and check if we can assign an AGPR at
228+
// that point.
229+
//
230+
// If any of the MFMAs aren't reassignable, we give up and rollback to
231+
// the original register assignments.
232+
233+
using RecoloringStack =
234+
SmallVector<std::pair<const LiveInterval *, MCRegister>, 8>;
235+
RecoloringStack TentativeReassignments;
236+
237+
for (Register RewriteReg : RewriteRegs) {
238+
LiveInterval &LI = LIS.getInterval(RewriteReg);
239+
TentativeReassignments.push_back({&LI, VRM.getPhys(RewriteReg)});
240+
LRM.unassign(LI);
241+
}
242+
243+
if (!attemptReassignmentsToAGPR(RewriteRegs, PhysRegHint)) {
244+
// Roll back the register assignments to the original state.
245+
for (auto [LI, OldAssign] : TentativeReassignments) {
246+
if (VRM.hasPhys(LI->reg()))
247+
LRM.unassign(*LI);
248+
LRM.assign(*LI, OldAssign);
249+
}
250+
251+
return false;
252+
}
253+
254+
// Fixup the register classes of the virtual registers now that we've
255+
// committed to the reassignments.
256+
for (Register InterferingReg : RewriteRegs) {
257+
const TargetRegisterClass *EquivalentAGPRRegClass =
258+
TRI.getEquivalentAGPRClass(MRI.getRegClass(InterferingReg));
259+
MRI.setRegClass(InterferingReg, EquivalentAGPRRegClass);
260+
}
261+
262+
for (MachineInstr *RewriteCandidate : RewriteCandidates) {
263+
int NewMFMAOp =
264+
AMDGPU::getMFMASrcCVDstAGPROp(RewriteCandidate->getOpcode());
265+
RewriteCandidate->setDesc(TII.get(NewMFMAOp));
266+
}
267+
268+
return true;
269+
}
270+
157271
/// Attempt to reassign the registers in \p InterferingRegs to be AGPRs, with a
158272
/// preference to use \p PhysReg first. Returns false if the reassignments
159273
/// cannot be trivially performed.
@@ -206,6 +320,78 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::attemptReassignmentsToAGPR(
206320
return true;
207321
}
208322

323+
/// Identify copies that look like:
324+
/// %vdst:vgpr = V_MFMA_.. %src0:av, %src1:av, %src2:vgpr
325+
/// %agpr = COPY %vgpr
326+
///
327+
/// Then try to replace the transitive uses of %src2 and %vdst with the AGPR
328+
/// versions of the MFMA. This should cover the common case.
329+
bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesToAGPR(
330+
Register VReg, MCRegister AssignedAGPR) const {
331+
bool MadeChange = false;
332+
for (MachineInstr &UseMI : MRI.def_instructions(VReg)) {
333+
if (!UseMI.isCopy())
334+
continue;
335+
336+
Register CopySrcReg = UseMI.getOperand(1).getReg();
337+
if (!CopySrcReg.isVirtual())
338+
continue;
339+
340+
// TODO: Handle loop phis copied to AGPR. e.g.
341+
//
342+
// loop:
343+
// %phi:vgpr = COPY %mfma:vgpr
344+
// %mfma:vgpr = V_MFMA_xxx_vgprcd_e64 %a, %b, %phi
345+
// s_cbranch_vccnz loop
346+
//
347+
// endloop:
348+
// %agpr = mfma
349+
//
350+
// We need to be sure that %phi is assigned to the same physical register as
351+
// %mfma, or else we will just be moving copies into the loop.
352+
353+
for (MachineInstr &CopySrcDefMI : MRI.def_instructions(CopySrcReg)) {
354+
if (isRewriteCandidate(CopySrcDefMI) &&
355+
tryReassigningMFMAChain(CopySrcDefMI, 0, AssignedAGPR))
356+
MadeChange = true;
357+
}
358+
}
359+
360+
return MadeChange;
361+
}
362+
363+
/// Identify copies that look like:
364+
/// %src:vgpr = COPY %src:agpr
365+
/// %vdst:vgpr = V_MFMA_... %src0:av, %src1:av, %src:vgpr
366+
///
367+
/// Then try to replace the transitive uses of %src2 and %vdst with the AGPR
368+
/// versions of the MFMA. This should cover rarer cases, and will generally be
369+
/// redundant with tryFoldCopiesToAGPR.
370+
bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
371+
Register VReg, MCRegister AssignedAGPR) const {
372+
bool MadeChange = false;
373+
for (MachineInstr &UseMI : MRI.use_instructions(VReg)) {
374+
if (!UseMI.isCopy())
375+
continue;
376+
377+
Register CopyDstReg = UseMI.getOperand(0).getReg();
378+
if (!CopyDstReg.isVirtual())
379+
continue;
380+
381+
for (MachineInstr &CopyUseMI : MRI.use_instructions(CopyDstReg)) {
382+
if (isRewriteCandidate(CopyUseMI)) {
383+
const MachineOperand *Op =
384+
CopyUseMI.findRegisterUseOperand(CopyDstReg, /*TRI=*/nullptr);
385+
if (tryReassigningMFMAChain(CopyUseMI, Op->getOperandNo(),
386+
VRM.getPhys(Op->getReg())))
387+
MadeChange = true;
388+
}
389+
}
390+
}
391+
392+
return MadeChange;
393+
}
394+
209395
bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
210396
// This only applies on subtargets that have a configurable AGPR vs. VGPR
211397
// allocation.
@@ -222,124 +408,14 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
222408

223409
for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
224410
Register VReg = Register::index2VirtReg(I);
225-
Register PhysReg = VRM.getPhys(VReg);
226-
if (!PhysReg)
227-
continue;
228-
229-
// Find AV_* registers assigned to AGPRs.
230-
const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
231-
if (!TRI.hasAGPRs(VirtRegRC))
411+
MCRegister AssignedAGPR = getAssignedAGPR(VReg);
412+
if (!AssignedAGPR)
232413
continue;
233414

234-
const TargetRegisterClass *AssignedRC = VirtRegRC;
235-
if (TRI.hasVGPRs(VirtRegRC)) {
236-
// If this is an AV register, we have to check if the actual assignment is
237-
// to an AGPR
238-
AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
239-
if (!TRI.isAGPRClass(AssignedRC))
240-
continue;
241-
}
242-
243-
LiveInterval &LI = LIS.getInterval(VReg);
244-
245-
for (VNInfo *VNI : LI.vnis()) {
246-
if (VNI->isPHIDef() || VNI->isUnused())
247-
continue;
248-
249-
MachineInstr *DefMI = LIS.getInstructionFromIndex(VNI->def);
250-
if (!DefMI || !DefMI->isCopy())
251-
continue;
252-
253-
Register MFMADstReg = DefMI->getOperand(1).getReg();
254-
if (!MFMADstReg.isVirtual())
255-
continue;
256-
257-
LiveInterval &CopySrcLI = LIS.getInterval(MFMADstReg);
258-
LiveQueryResult LRQ = CopySrcLI.Query(VNI->def.getRegSlot());
259-
MachineInstr *MFMA = LIS.getInstructionFromIndex(LRQ.valueIn()->def);
260-
if (!MFMA || !isRewriteCandidate(*MFMA))
261-
continue;
262-
263-
// src2 and dst have the same physical class constraint; try to preserve
264-
// the original src2 subclass if one were to exist.
265-
SmallVector<MachineInstr *, 4> RewriteCandidates = {MFMA};
266-
SmallSetVector<Register, 4> RewriteRegs;
267-
268-
// Make sure we reassign the MFMA we found the copy from first. We want
269-
// to ensure dst ends up in the physreg we were originally copying to.
270-
RewriteRegs.insert(MFMADstReg);
271-
272-
// We've found av = COPY (MFMA), and need to verify that we can trivially
273-
// rewrite src2 to use the new AGPR. If we can't trivially replace it,
274-
// we're going to induce as many copies as we would have emitted in the
275-
// first place, as well as need to assign another register, and need to
276-
// figure out where to put them. The live range splitting is smarter than
277-
// anything we're doing here, so trust it did something reasonable.
278-
//
279-
// Note recomputeRegClassExceptRewritable will consider the constraints of
280-
// this MFMA's src2 as well as the src2/dst of any transitive MFMA users.
281-
if (!recomputeRegClassExceptRewritable(MFMADstReg, RewriteCandidates,
282-
RewriteRegs)) {
283-
LLVM_DEBUG(dbgs() << "Could not recompute the regclass of dst reg "
284-
<< printReg(MFMADstReg, &TRI) << '\n');
285-
continue;
286-
}
287-
288-
// If src2 and dst are different registers, we need to also reassign the
289-
// input to an available AGPR if it is compatible with all other uses.
290-
//
291-
// If we can't reassign it, we'd need to introduce a different copy
292-
// which is likely worse than the copy we'd be saving.
293-
//
294-
// It's likely that the MFMA is used in sequence with other MFMAs; if we
295-
// cannot migrate the full use/def chain of MFMAs, we would need to
296-
// introduce intermediate copies somewhere. So we only make the
297-
// transform if all the interfering MFMAs can also be migrated. Collect
298-
// the set of rewritable MFMAs and check if we can assign an AGPR at
299-
// that point.
300-
//
301-
// If any of the MFMAs aren't reassignable, we give up and rollback to
302-
// the original register assignments.
303-
304-
using RecoloringStack =
305-
SmallVector<std::pair<const LiveInterval *, MCRegister>, 8>;
306-
RecoloringStack TentativeReassignments;
307-
308-
for (Register RewriteReg : RewriteRegs) {
309-
LiveInterval &LI = LIS.getInterval(RewriteReg);
310-
TentativeReassignments.push_back({&LI, VRM.getPhys(RewriteReg)});
311-
LRM.unassign(LI);
312-
}
313-
314-
if (!attemptReassignmentsToAGPR(RewriteRegs, PhysReg)) {
315-
// Roll back the register assignments to the original state.
316-
for (auto [LI, OldAssign] : TentativeReassignments) {
317-
if (VRM.hasPhys(LI->reg()))
318-
LRM.unassign(*LI);
319-
LRM.assign(*LI, OldAssign);
320-
}
321-
322-
continue;
323-
}
324-
325-
// Fixup the register classes of the virtual registers now that we've
326-
// committed to the reassignments.
327-
for (Register InterferingReg : RewriteRegs) {
328-
const TargetRegisterClass *EquivalentAGPRRegClass =
329-
TRI.getEquivalentAGPRClass(MRI.getRegClass(InterferingReg));
330-
MRI.setRegClass(InterferingReg, EquivalentAGPRRegClass);
331-
}
332-
333-
for (MachineInstr *RewriteCandidate : RewriteCandidates) {
334-
int NewMFMAOp =
335-
AMDGPU::getMFMASrcCVDstAGPROp(RewriteCandidate->getOpcode());
336-
RewriteCandidate->setDesc(TII.get(NewMFMAOp));
337-
}
338-
339-
// We likely left an identity copy behind after assignment; let
340-
// VirtRegRewriter deal with it later.
415+
if (tryFoldCopiesToAGPR(VReg, AssignedAGPR))
416+
MadeChange = true;
417+
if (tryFoldCopiesFromAGPR(VReg, AssignedAGPR))
341418
MadeChange = true;
342-
}
343419
}
344420

345421
return MadeChange;

0 commit comments

Comments
 (0)