2929#include " llvm/ADT/Statistic.h"
3030#include " llvm/CodeGen/LiveIntervals.h"
3131#include " llvm/CodeGen/LiveRegMatrix.h"
32+ #include " llvm/CodeGen/LiveStacks.h"
33+ #include " llvm/CodeGen/MachineFrameInfo.h"
3234#include " llvm/CodeGen/MachineFunctionPass.h"
3335#include " llvm/CodeGen/VirtRegMap.h"
3436#include " llvm/InitializePasses.h"
@@ -42,6 +44,9 @@ namespace {
4244STATISTIC (NumMFMAsRewrittenToAGPR,
4345 " Number of MFMA instructions rewritten to use AGPR form" );
4446
47+ // / Map from spill slot frame index to list of instructions which reference it.
48+ using SpillReferenceMap = DenseMap<int , SmallVector<MachineInstr *, 4 >>;
49+
4550class AMDGPURewriteAGPRCopyMFMAImpl {
4651 MachineFunction &MF;
4752 const GCNSubtarget &ST;
@@ -51,6 +56,7 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
5156 VirtRegMap &VRM;
5257 LiveRegMatrix &LRM;
5358 LiveIntervals &LIS;
59+ LiveStacks &LSS;
5460 const RegisterClassInfo &RegClassInfo;
5561
5662 bool attemptReassignmentsToAGPR (SmallSetVector<Register, 4 > &InterferingRegs,
@@ -59,10 +65,11 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
5965public:
6066 AMDGPURewriteAGPRCopyMFMAImpl (MachineFunction &MF, VirtRegMap &VRM,
6167 LiveRegMatrix &LRM, LiveIntervals &LIS,
68+ LiveStacks &LSS,
6269 const RegisterClassInfo &RegClassInfo)
6370 : MF(MF), ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
6471 TRI (*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
65- LIS(LIS), RegClassInfo(RegClassInfo) {}
72+ LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
6673
6774 bool isRewriteCandidate (const MachineInstr &MI) const {
6875 return TII.isMAI (MI) && AMDGPU::getMFMASrcCVDstAGPROp (MI.getOpcode ()) != -1 ;
@@ -103,6 +110,22 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
103110
104111 bool tryFoldCopiesToAGPR (Register VReg, MCRegister AssignedAGPR) const ;
105112 bool tryFoldCopiesFromAGPR (Register VReg, MCRegister AssignedAGPR) const ;
113+
114+ // / Replace spill instruction \p SpillMI which loads/stores from/to \p SpillFI
115+ // / with a COPY to the replacement register value \p VReg.
116+ void replaceSpillWithCopyToVReg (MachineInstr &SpillMI, int SpillFI,
117+ Register VReg) const ;
118+
119+ // / Create a map from frame index to use instructions for spills. If a use of
120+ // / the frame index does not consist only of spill instructions, it will not
121+ // / be included in the map.
122+ void collectSpillIndexUses (ArrayRef<LiveInterval *> StackIntervals,
123+ SpillReferenceMap &Map) const ;
124+
125+ // / Attempt to unspill VGPRs by finding a free register and replacing the
126+ // / spill instructions with copies.
127+ void eliminateSpillsOfReassignedVGPRs () const ;
128+
106129 bool run (MachineFunction &MF) const ;
107130};
108131
@@ -391,6 +414,138 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
391414 return MadeChange;
392415}
393416
417+ void AMDGPURewriteAGPRCopyMFMAImpl::replaceSpillWithCopyToVReg (
418+ MachineInstr &SpillMI, int SpillFI, Register VReg) const {
419+ const DebugLoc &DL = SpillMI.getDebugLoc ();
420+ MachineBasicBlock &MBB = *SpillMI.getParent ();
421+ MachineInstr *NewCopy;
422+ if (SpillMI.mayStore ()) {
423+ NewCopy = BuildMI (MBB, SpillMI, DL, TII.get (TargetOpcode::COPY), VReg)
424+ .add (SpillMI.getOperand (0 ));
425+ } else {
426+ NewCopy = BuildMI (MBB, SpillMI, DL, TII.get (TargetOpcode::COPY))
427+ .add (SpillMI.getOperand (0 ))
428+ .addReg (VReg);
429+ }
430+
431+ LIS.ReplaceMachineInstrInMaps (SpillMI, *NewCopy);
432+ SpillMI.eraseFromParent ();
433+ }
434+
435+ void AMDGPURewriteAGPRCopyMFMAImpl::collectSpillIndexUses (
436+ ArrayRef<LiveInterval *> StackIntervals, SpillReferenceMap &Map) const {
437+
438+ SmallSet<int , 4 > NeededFrameIndexes;
439+ for (const LiveInterval *LI : StackIntervals)
440+ NeededFrameIndexes.insert (LI->reg ().stackSlotIndex ());
441+
442+ for (MachineBasicBlock &MBB : MF) {
443+ for (MachineInstr &MI : MBB) {
444+ for (MachineOperand &MO : MI.operands ()) {
445+ if (!MO.isFI () || !NeededFrameIndexes.count (MO.getIndex ()))
446+ continue ;
447+
448+ if (TII.isVGPRSpill (MI)) {
449+ SmallVector<MachineInstr *, 4 > &References = Map[MO.getIndex ()];
450+ References.push_back (&MI);
451+ break ;
452+ }
453+
454+ // Verify this was really a spill instruction, if it's not just ignore
455+ // all uses.
456+
457+ // TODO: This should probably be verifier enforced.
458+ NeededFrameIndexes.erase (MO.getIndex ());
459+ Map.erase (MO.getIndex ());
460+ }
461+ }
462+ }
463+ }
464+
465+ void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs () const {
466+ unsigned NumSlots = LSS.getNumIntervals ();
467+ if (NumSlots == 0 )
468+ return ;
469+
470+ MachineFrameInfo &MFI = MF.getFrameInfo ();
471+
472+ SmallVector<LiveInterval *, 32 > StackIntervals;
473+ StackIntervals.reserve (NumSlots);
474+
475+ for (auto &[Slot, LI] : LSS) {
476+ if (!MFI.isSpillSlotObjectIndex (Slot) || MFI.isDeadObjectIndex (Slot))
477+ continue ;
478+
479+ const TargetRegisterClass *RC = LSS.getIntervalRegClass (Slot);
480+ if (TRI.hasVGPRs (RC))
481+ StackIntervals.push_back (&LI);
482+ }
483+
484+ sort (StackIntervals, [](const LiveInterval *A, const LiveInterval *B) {
485+ // / Sort heaviest intervals first to prioritize their unspilling
486+ if (A->weight () > B->weight ())
487+ return true ;
488+
489+ if (A->getSize () > B->getSize ())
490+ return true ;
491+
492+ // Tie breaker by number to avoid need for stable sort
493+ return A->reg ().stackSlotIndex () < B->reg ().stackSlotIndex ();
494+ });
495+
496+ // FIXME: The APIs for dealing with the LiveInterval of a frame index are
497+ // cumbersome. LiveStacks owns its LiveIntervals which refer to stack
498+ // slots. We cannot use the usual LiveRegMatrix::assign and unassign on these,
499+ // and must create a substitute virtual register to do so. This makes
500+ // incremental updating here difficult; we need to actually perform the IR
501+ // mutation to get the new vreg references in place to compute the register
502+ // LiveInterval to perform an assignment to track the new interference
503+ // correctly, and we can't simply migrate the LiveInterval we already have.
504+ //
505+ // To avoid walking through the entire function for each index, pre-collect
506+ // all the instructions slot referencess.
507+
508+ DenseMap<int , SmallVector<MachineInstr *, 4 >> SpillSlotReferences;
509+ collectSpillIndexUses (StackIntervals, SpillSlotReferences);
510+
511+ for (LiveInterval *LI : StackIntervals) {
512+ int Slot = LI->reg ().stackSlotIndex ();
513+ auto SpillReferences = SpillSlotReferences.find (Slot);
514+ if (SpillReferences == SpillSlotReferences.end ())
515+ continue ;
516+
517+ const TargetRegisterClass *RC = LSS.getIntervalRegClass (Slot);
518+
519+ LLVM_DEBUG (dbgs () << " Trying to eliminate " << printReg (Slot, &TRI)
520+ << " by reassigning\n " );
521+
522+ ArrayRef<MCPhysReg> AllocOrder = RegClassInfo.getOrder (RC);
523+
524+ for (MCPhysReg PhysReg : AllocOrder) {
525+ if (LRM.checkInterference (*LI, PhysReg) != LiveRegMatrix::IK_Free)
526+ continue ;
527+
528+ LLVM_DEBUG (dbgs () << " Reassigning " << *LI << " to "
529+ << printReg (PhysReg, &TRI) << ' \n ' );
530+
531+ const TargetRegisterClass *RC = LSS.getIntervalRegClass (Slot);
532+ Register NewVReg = MRI.createVirtualRegister (RC);
533+
534+ for (MachineInstr *SpillMI : SpillReferences->second )
535+ replaceSpillWithCopyToVReg (*SpillMI, Slot, NewVReg);
536+
537+ // TODO: We should be able to transfer the information from the stack
538+ // slot's LiveInterval without recomputing from scratch with the
539+ // replacement vreg uses.
540+ LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval (NewVReg);
541+ VRM.grow ();
542+ LRM.assign (NewLI, PhysReg);
543+ MFI.RemoveStackObject (Slot);
544+ break ;
545+ }
546+ }
547+ }
548+
394549bool AMDGPURewriteAGPRCopyMFMAImpl::run (MachineFunction &MF) const {
395550 // This only applies on subtargets that have a configurable AGPR vs. VGPR
396551 // allocation.
@@ -417,6 +572,12 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
417572 MadeChange = true ;
418573 }
419574
575+ // If we've successfully rewritten some MFMAs, we've alleviated some VGPR
576+ // pressure. See if we can eliminate some spills now that those registers are
577+ // more available.
578+ if (MadeChange)
579+ eliminateSpillsOfReassignedVGPRs ();
580+
420581 return MadeChange;
421582}
422583
@@ -440,10 +601,13 @@ class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
440601 AU.addRequired <LiveIntervalsWrapperPass>();
441602 AU.addRequired <VirtRegMapWrapperLegacy>();
442603 AU.addRequired <LiveRegMatrixWrapperLegacy>();
604+ AU.addRequired <LiveStacksWrapperLegacy>();
443605
444606 AU.addPreserved <LiveIntervalsWrapperPass>();
445607 AU.addPreserved <VirtRegMapWrapperLegacy>();
446608 AU.addPreserved <LiveRegMatrixWrapperLegacy>();
609+ AU.addPreserved <LiveStacksWrapperLegacy>();
610+
447611 AU.setPreservesAll ();
448612 MachineFunctionPass::getAnalysisUsage (AU);
449613 }
@@ -456,6 +620,7 @@ INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
456620INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
457621INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
458622INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
623+ INITIALIZE_PASS_DEPENDENCY(LiveStacksWrapperLegacy)
459624INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
460625 " AMDGPU Rewrite AGPR-Copy-MFMA" , false , false )
461626
@@ -474,8 +639,8 @@ bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
474639 auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM ();
475640 auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM ();
476641 auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS ();
477-
478- AMDGPURewriteAGPRCopyMFMAImpl Impl (MF, VRM, LRM, LIS, RegClassInfo);
642+ auto &LSS = getAnalysis<LiveStacksWrapperLegacy>(). getLS ();
643+ AMDGPURewriteAGPRCopyMFMAImpl Impl (MF, VRM, LRM, LIS, LSS, RegClassInfo);
479644 return Impl.run (MF);
480645}
481646
@@ -485,13 +650,15 @@ AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
485650 VirtRegMap &VRM = MFAM.getResult <VirtRegMapAnalysis>(MF);
486651 LiveRegMatrix &LRM = MFAM.getResult <LiveRegMatrixAnalysis>(MF);
487652 LiveIntervals &LIS = MFAM.getResult <LiveIntervalsAnalysis>(MF);
653+ LiveStacks &LSS = MFAM.getResult <LiveStacksAnalysis>(MF);
488654 RegisterClassInfo RegClassInfo;
489655 RegClassInfo.runOnMachineFunction (MF);
490656
491- AMDGPURewriteAGPRCopyMFMAImpl Impl (MF, VRM, LRM, LIS, RegClassInfo);
657+ AMDGPURewriteAGPRCopyMFMAImpl Impl (MF, VRM, LRM, LIS, LSS, RegClassInfo);
492658 if (!Impl.run (MF))
493659 return PreservedAnalyses::all ();
494660 auto PA = getMachineFunctionPassPreservedAnalyses ();
495661 PA.preserveSet <CFGAnalyses>();
662+ PA.preserve <LiveStacksAnalysis>();
496663 return PA;
497664}
0 commit comments