Commit babdad3

AMDGPU: Try to unspill VGPRs after rewriting MFMAs to AGPR form (#154323)
After replacing VGPR MFMAs with the AGPR form, VGPR pressure is alleviated, which may make spills introduced during allocation unnecessary. Identify these spill slots, try to reassign them to newly freed VGPRs, and replace the spill instructions with copies.

Fixes #154260
1 parent f017bcb commit babdad3
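
For orientation, here is a minimal sketch of the unspilling flow this patch adds, condensed from eliminateSpillsOfReassignedVGPRs in the diff below. It uses only the LLVM CodeGen APIs that appear in the patch; the heaviest-first sort, the AMDGPU-specific VGPR register-class filter, and the spill-reference map are elided, and the free function unspillFreedVGPRs is an illustrative name, not part of the patch.

// Sketch only: walk the live stack slots, and for each spill slot whose live
// range no longer interferes with some physical register in the allocation
// order, retire the slot and rewrite its spills/reloads as COPYs through a
// fresh virtual register assigned to that physical register.
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/LiveStacks.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterClassInfo.h"
#include "llvm/CodeGen/VirtRegMap.h"

using namespace llvm;

static void unspillFreedVGPRs(MachineFunction &MF, LiveStacks &LSS,
                              LiveRegMatrix &LRM, LiveIntervals &LIS,
                              VirtRegMap &VRM,
                              const RegisterClassInfo &RegClassInfo) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  MachineFrameInfo &MFI = MF.getFrameInfo();

  for (auto &[Slot, LI] : LSS) {
    // Only consider live spill slots (the patch additionally restricts this
    // to VGPR-class slots and visits them heaviest-first).
    if (!MFI.isSpillSlotObjectIndex(Slot) || MFI.isDeadObjectIndex(Slot))
      continue;

    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
    for (MCPhysReg PhysReg : RegClassInfo.getOrder(RC)) {
      // A register freed by the MFMA rewrite shows no interference with the
      // stack slot's live range.
      if (LRM.checkInterference(LI, PhysReg) != LiveRegMatrix::IK_Free)
        continue;

      Register NewVReg = MRI.createVirtualRegister(RC);
      // ... rewrite every spill/reload referencing Slot into a COPY to/from
      // NewVReg (replaceSpillWithCopyToVReg in the patch) ...

      // Register the replacement with the allocator state and drop the slot.
      LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval(NewVReg);
      VRM.grow();
      LRM.assign(NewLI, PhysReg);
      MFI.RemoveStackObject(Slot);
      break;
    }
  }
}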

2 files changed (+178 / -41 lines)

llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp

Lines changed: 171 additions & 4 deletions
@@ -29,6 +29,8 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/LiveRegMatrix.h"
+#include "llvm/CodeGen/LiveStacks.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/VirtRegMap.h"
 #include "llvm/InitializePasses.h"
@@ -42,6 +44,9 @@ namespace {
 STATISTIC(NumMFMAsRewrittenToAGPR,
           "Number of MFMA instructions rewritten to use AGPR form");
 
+/// Map from spill slot frame index to list of instructions which reference it.
+using SpillReferenceMap = DenseMap<int, SmallVector<MachineInstr *, 4>>;
+
 class AMDGPURewriteAGPRCopyMFMAImpl {
   MachineFunction &MF;
   const GCNSubtarget &ST;
@@ -51,6 +56,7 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
   VirtRegMap &VRM;
   LiveRegMatrix &LRM;
   LiveIntervals &LIS;
+  LiveStacks &LSS;
   const RegisterClassInfo &RegClassInfo;
 
   bool attemptReassignmentsToAGPR(SmallSetVector<Register, 4> &InterferingRegs,
@@ -59,10 +65,11 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 public:
   AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
                                 LiveRegMatrix &LRM, LiveIntervals &LIS,
+                                LiveStacks &LSS,
                                 const RegisterClassInfo &RegClassInfo)
       : MF(MF), ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
         TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
-        LIS(LIS), RegClassInfo(RegClassInfo) {}
+        LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
 
   bool isRewriteCandidate(const MachineInstr &MI) const {
     return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
@@ -103,6 +110,22 @@
 
   bool tryFoldCopiesToAGPR(Register VReg, MCRegister AssignedAGPR) const;
   bool tryFoldCopiesFromAGPR(Register VReg, MCRegister AssignedAGPR) const;
+
+  /// Replace spill instruction \p SpillMI which loads/stores from/to \p SpillFI
+  /// with a COPY to the replacement register value \p VReg.
+  void replaceSpillWithCopyToVReg(MachineInstr &SpillMI, int SpillFI,
+                                  Register VReg) const;
+
+  /// Create a map from frame index to use instructions for spills. If a use of
+  /// the frame index does not consist only of spill instructions, it will not
+  /// be included in the map.
+  void collectSpillIndexUses(ArrayRef<LiveInterval *> StackIntervals,
+                             SpillReferenceMap &Map) const;
+
+  /// Attempt to unspill VGPRs by finding a free register and replacing the
+  /// spill instructions with copies.
+  void eliminateSpillsOfReassignedVGPRs() const;
+
   bool run(MachineFunction &MF) const;
 };
 
@@ -391,6 +414,138 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
   return MadeChange;
 }
 
+void AMDGPURewriteAGPRCopyMFMAImpl::replaceSpillWithCopyToVReg(
+    MachineInstr &SpillMI, int SpillFI, Register VReg) const {
+  const DebugLoc &DL = SpillMI.getDebugLoc();
+  MachineBasicBlock &MBB = *SpillMI.getParent();
+  MachineInstr *NewCopy;
+  if (SpillMI.mayStore()) {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY), VReg)
+                  .add(SpillMI.getOperand(0));
+  } else {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY))
+                  .add(SpillMI.getOperand(0))
+                  .addReg(VReg);
+  }
+
+  LIS.ReplaceMachineInstrInMaps(SpillMI, *NewCopy);
+  SpillMI.eraseFromParent();
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::collectSpillIndexUses(
+    ArrayRef<LiveInterval *> StackIntervals, SpillReferenceMap &Map) const {
+
+  SmallSet<int, 4> NeededFrameIndexes;
+  for (const LiveInterval *LI : StackIntervals)
+    NeededFrameIndexes.insert(LI->reg().stackSlotIndex());
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      for (MachineOperand &MO : MI.operands()) {
+        if (!MO.isFI() || !NeededFrameIndexes.count(MO.getIndex()))
+          continue;
+
+        if (TII.isVGPRSpill(MI)) {
+          SmallVector<MachineInstr *, 4> &References = Map[MO.getIndex()];
+          References.push_back(&MI);
+          break;
+        }
+
+        // Verify this was really a spill instruction; if it's not, just ignore
+        // all uses.
+
+        // TODO: This should probably be verifier enforced.
+        NeededFrameIndexes.erase(MO.getIndex());
+        Map.erase(MO.getIndex());
+      }
+    }
+  }
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs() const {
+  unsigned NumSlots = LSS.getNumIntervals();
+  if (NumSlots == 0)
+    return;
+
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  SmallVector<LiveInterval *, 32> StackIntervals;
+  StackIntervals.reserve(NumSlots);
+
+  for (auto &[Slot, LI] : LSS) {
+    if (!MFI.isSpillSlotObjectIndex(Slot) || MFI.isDeadObjectIndex(Slot))
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+    if (TRI.hasVGPRs(RC))
+      StackIntervals.push_back(&LI);
+  }
+
+  sort(StackIntervals, [](const LiveInterval *A, const LiveInterval *B) {
+    /// Sort heaviest intervals first to prioritize their unspilling.
+    if (A->weight() > B->weight())
+      return true;
+
+    if (A->getSize() > B->getSize())
+      return true;
+
+    // Tie break by slot number to avoid the need for a stable sort.
+    return A->reg().stackSlotIndex() < B->reg().stackSlotIndex();
+  });
+
+  // FIXME: The APIs for dealing with the LiveInterval of a frame index are
+  // cumbersome. LiveStacks owns its LiveIntervals which refer to stack
+  // slots. We cannot use the usual LiveRegMatrix::assign and unassign on these,
+  // and must create a substitute virtual register to do so. This makes
+  // incremental updating here difficult; we need to actually perform the IR
+  // mutation to get the new vreg references in place to compute the register
+  // LiveInterval to perform an assignment to track the new interference
+  // correctly, and we can't simply migrate the LiveInterval we already have.
+  //
+  // To avoid walking through the entire function for each index, pre-collect
+  // all of the instructions' slot references.
+
+  DenseMap<int, SmallVector<MachineInstr *, 4>> SpillSlotReferences;
+  collectSpillIndexUses(StackIntervals, SpillSlotReferences);
+
+  for (LiveInterval *LI : StackIntervals) {
+    int Slot = LI->reg().stackSlotIndex();
+    auto SpillReferences = SpillSlotReferences.find(Slot);
+    if (SpillReferences == SpillSlotReferences.end())
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+
+    LLVM_DEBUG(dbgs() << "Trying to eliminate " << printReg(Slot, &TRI)
+                      << " by reassigning\n");
+
+    ArrayRef<MCPhysReg> AllocOrder = RegClassInfo.getOrder(RC);
+
+    for (MCPhysReg PhysReg : AllocOrder) {
+      if (LRM.checkInterference(*LI, PhysReg) != LiveRegMatrix::IK_Free)
+        continue;
+
+      LLVM_DEBUG(dbgs() << "Reassigning " << *LI << " to "
+                        << printReg(PhysReg, &TRI) << '\n');
+
+      const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+      Register NewVReg = MRI.createVirtualRegister(RC);
+
+      for (MachineInstr *SpillMI : SpillReferences->second)
+        replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);
+
+      // TODO: We should be able to transfer the information from the stack
+      // slot's LiveInterval without recomputing from scratch with the
+      // replacement vreg uses.
+      LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval(NewVReg);
+      VRM.grow();
+      LRM.assign(NewLI, PhysReg);
+      MFI.RemoveStackObject(Slot);
+      break;
+    }
+  }
+}
+
 bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
   // This only applies on subtargets that have a configurable AGPR vs. VGPR
   // allocation.
@@ -417,6 +572,12 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
     MadeChange = true;
   }
 
+  // If we've successfully rewritten some MFMAs, we've alleviated some VGPR
+  // pressure. See if we can eliminate some spills now that those registers are
+  // more available.
+  if (MadeChange)
+    eliminateSpillsOfReassignedVGPRs();
+
   return MadeChange;
 }
 
@@ -440,10 +601,13 @@ class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
     AU.addRequired<LiveIntervalsWrapperPass>();
    AU.addRequired<VirtRegMapWrapperLegacy>();
    AU.addRequired<LiveRegMatrixWrapperLegacy>();
+    AU.addRequired<LiveStacksWrapperLegacy>();
 
    AU.addPreserved<LiveIntervalsWrapperPass>();
    AU.addPreserved<VirtRegMapWrapperLegacy>();
    AU.addPreserved<LiveRegMatrixWrapperLegacy>();
+    AU.addPreserved<LiveStacksWrapperLegacy>();
+
    AU.setPreservesAll();
    MachineFunctionPass::getAnalysisUsage(AU);
  }
@@ -456,6 +620,7 @@ INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
 INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
 INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
+INITIALIZE_PASS_DEPENDENCY(LiveStacksWrapperLegacy)
 INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
                     "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
 
@@ -474,8 +639,8 @@ bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
   auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
   auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
   auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
-
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  auto &LSS = getAnalysis<LiveStacksWrapperLegacy>().getLS();
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   return Impl.run(MF);
 }
 
@@ -485,13 +650,15 @@ AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
   VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(MF);
   LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(MF);
   LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(MF);
+  LiveStacks &LSS = MFAM.getResult<LiveStacksAnalysis>(MF);
   RegisterClassInfo RegClassInfo;
   RegClassInfo.runOnMachineFunction(MF);
 
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   if (!Impl.run(MF))
     return PreservedAnalyses::all();
   auto PA = getMachineFunctionPassPreservedAnalyses();
   PA.preserveSet<CFGAnalyses>();
+  PA.preserve<LiveStacksAnalysis>();
   return PA;
 }

llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll

Lines changed: 7 additions & 37 deletions
@@ -101,13 +101,8 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT: v_accvgpr_read_b32 v2, a2
 ; CHECK-NEXT: v_accvgpr_read_b32 v3, a3
 ; CHECK-NEXT: ;;#ASMSTART
-; CHECK-NEXT: ; def v[0:3]
+; CHECK-NEXT: ; def v[10:13]
 ; CHECK-NEXT: ;;#ASMEND
-; CHECK-NEXT: buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT: s_nop 0
-; CHECK-NEXT: buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT: buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT: buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT: v_mov_b32_e32 v0, 0
 ; CHECK-NEXT: ;;#ASMSTART
 ; CHECK-NEXT: ; def a[0:31]
@@ -147,12 +142,7 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT: s_waitcnt vmcnt(0)
 ; CHECK-NEXT: global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT: s_waitcnt vmcnt(0)
-; CHECK-NEXT: buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT: s_waitcnt vmcnt(0)
-; CHECK-NEXT: global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT: global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT: s_waitcnt vmcnt(0)
 ; CHECK-NEXT: buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT: buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload
@@ -311,26 +301,16 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT: v_accvgpr_write_b32 a33, v1
 ; CHECK-NEXT: v_accvgpr_write_b32 a32, v0
 ; CHECK-NEXT: v_accvgpr_read_b32 v7, a3
+; CHECK-NEXT: v_mov_b32_e32 v0, 0
 ; CHECK-NEXT: v_accvgpr_read_b32 v6, a2
 ; CHECK-NEXT: v_accvgpr_read_b32 v5, a1
 ; CHECK-NEXT: v_accvgpr_read_b32 v4, a0
 ; CHECK-NEXT: ;;#ASMSTART
-; CHECK-NEXT: ; def v[0:3]
+; CHECK-NEXT: ; def v[10:13]
 ; CHECK-NEXT: ;;#ASMEND
-; CHECK-NEXT: buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT: s_nop 0
-; CHECK-NEXT: buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT: buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT: buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT: ;;#ASMSTART
-; CHECK-NEXT: ; def v[0:3]
+; CHECK-NEXT: ; def v[14:17]
 ; CHECK-NEXT: ;;#ASMEND
-; CHECK-NEXT: buffer_store_dword v0, off, s[0:3], s32 offset:208 ; 4-byte Folded Spill
-; CHECK-NEXT: s_nop 0
-; CHECK-NEXT: buffer_store_dword v1, off, s[0:3], s32 offset:212 ; 4-byte Folded Spill
-; CHECK-NEXT: buffer_store_dword v2, off, s[0:3], s32 offset:216 ; 4-byte Folded Spill
-; CHECK-NEXT: buffer_store_dword v3, off, s[0:3], s32 offset:220 ; 4-byte Folded Spill
-; CHECK-NEXT: v_mov_b32_e32 v0, 0
 ; CHECK-NEXT: ;;#ASMSTART
 ; CHECK-NEXT: ; def a[0:31]
 ; CHECK-NEXT: ;;#ASMEND
@@ -369,19 +349,9 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT: s_waitcnt vmcnt(0)
 ; CHECK-NEXT: global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT: s_waitcnt vmcnt(0)
-; CHECK-NEXT: buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT: s_waitcnt vmcnt(0)
-; CHECK-NEXT: global_store_dwordx4 v0, v[2:5], s[16:17]
-; CHECK-NEXT: s_waitcnt vmcnt(0)
-; CHECK-NEXT: buffer_load_dword v2, off, s[0:3], s32 offset:208 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v3, off, s[0:3], s32 offset:212 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v4, off, s[0:3], s32 offset:216 ; 4-byte Folded Reload
-; CHECK-NEXT: buffer_load_dword v5, off, s[0:3], s32 offset:220 ; 4-byte Folded Reload
+; CHECK-NEXT: global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT: s_waitcnt vmcnt(0)
-; CHECK-NEXT: global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT: global_store_dwordx4 v0, v[14:17], s[16:17]
 ; CHECK-NEXT: s_waitcnt vmcnt(0)
 ; CHECK-NEXT: buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT: buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload
