@@ -29,6 +29,8 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/LiveRegMatrix.h"
+#include "llvm/CodeGen/LiveStacks.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/VirtRegMap.h"
 #include "llvm/InitializePasses.h"
@@ -42,6 +44,9 @@ namespace {
 STATISTIC(NumMFMAsRewrittenToAGPR,
           "Number of MFMA instructions rewritten to use AGPR form");
 
+/// Map from spill slot frame index to list of instructions which reference it.
+using SpillReferenceMap = DenseMap<int, SmallVector<MachineInstr *, 4>>;
+
 class AMDGPURewriteAGPRCopyMFMAImpl {
   MachineFunction &MF;
   const GCNSubtarget &ST;
@@ -51,6 +56,7 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
   VirtRegMap &VRM;
   LiveRegMatrix &LRM;
   LiveIntervals &LIS;
+  LiveStacks &LSS;
   const RegisterClassInfo &RegClassInfo;
 
   bool attemptReassignmentsToAGPR(SmallSetVector<Register, 4> &InterferingRegs,
@@ -59,10 +65,11 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 public:
   AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
                                 LiveRegMatrix &LRM, LiveIntervals &LIS,
+                                LiveStacks &LSS,
                                 const RegisterClassInfo &RegClassInfo)
       : MF(MF), ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
         TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
-        LIS(LIS), RegClassInfo(RegClassInfo) {}
+        LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
 
   bool isRewriteCandidate(const MachineInstr &MI) const {
     return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
@@ -103,6 +110,22 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 
   bool tryFoldCopiesToAGPR(Register VReg, MCRegister AssignedAGPR) const;
   bool tryFoldCopiesFromAGPR(Register VReg, MCRegister AssignedAGPR) const;
+
+  /// Replace spill instruction \p SpillMI which loads/stores from/to \p SpillFI
+  /// with a COPY to the replacement register value \p VReg.
+  void replaceSpillWithCopyToVReg(MachineInstr &SpillMI, int SpillFI,
+                                  Register VReg) const;
+
+  /// Create a map from frame index to use instructions for spills. If a use of
+  /// the frame index does not consist only of spill instructions, it will not
+  /// be included in the map.
+  void collectSpillIndexUses(ArrayRef<LiveInterval *> StackIntervals,
+                             SpillReferenceMap &Map) const;
+
+  /// Attempt to unspill VGPRs by finding a free register and replacing the
+  /// spill instructions with copies.
+  void eliminateSpillsOfReassignedVGPRs() const;
+
   bool run(MachineFunction &MF) const;
 };
 
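For orientation, the three helpers declared above are intended to compose roughly as follows. This is an illustrative sketch only, not code from the commit; StackIntervals, Slot, and NewVReg are placeholders for values the implementation computes further down in the file.

    // Illustrative sketch; placeholder names, not part of the commit.
    SpillReferenceMap Refs;                      // frame index -> spill instructions
    collectSpillIndexUses(StackIntervals, Refs); // drops slots with non-spill uses
    for (MachineInstr *SpillMI : Refs[Slot])     // Slot: a spill slot being removed
      replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);
    // eliminateSpillsOfReassignedVGPRs() drives this, choosing Slot and NewVReg
    // after finding a free physical register via LiveRegMatrix.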
@@ -391,6 +414,138 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
   return MadeChange;
 }
 
+void AMDGPURewriteAGPRCopyMFMAImpl::replaceSpillWithCopyToVReg(
+    MachineInstr &SpillMI, int SpillFI, Register VReg) const {
+  const DebugLoc &DL = SpillMI.getDebugLoc();
+  MachineBasicBlock &MBB = *SpillMI.getParent();
+  MachineInstr *NewCopy;
+  if (SpillMI.mayStore()) {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY), VReg)
+                  .add(SpillMI.getOperand(0));
+  } else {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY))
+                  .add(SpillMI.getOperand(0))
+                  .addReg(VReg);
+  }
+
+  LIS.ReplaceMachineInstrInMaps(SpillMI, *NewCopy);
+  SpillMI.eraseFromParent();
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::collectSpillIndexUses(
+    ArrayRef<LiveInterval *> StackIntervals, SpillReferenceMap &Map) const {
+
+  SmallSet<int, 4> NeededFrameIndexes;
+  for (const LiveInterval *LI : StackIntervals)
+    NeededFrameIndexes.insert(LI->reg().stackSlotIndex());
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      for (MachineOperand &MO : MI.operands()) {
+        if (!MO.isFI() || !NeededFrameIndexes.count(MO.getIndex()))
+          continue;
+
+        if (TII.isVGPRSpill(MI)) {
+          SmallVector<MachineInstr *, 4> &References = Map[MO.getIndex()];
+          References.push_back(&MI);
+          break;
+        }
+
+        // Verify this was really a spill instruction; if it's not, just
+        // ignore all uses of this index.
+
+        // TODO: This should probably be verifier enforced.
+        NeededFrameIndexes.erase(MO.getIndex());
+        Map.erase(MO.getIndex());
+      }
+    }
+  }
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs() const {
+  unsigned NumSlots = LSS.getNumIntervals();
+  if (NumSlots == 0)
+    return;
+
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  SmallVector<LiveInterval *, 32> StackIntervals;
+  StackIntervals.reserve(NumSlots);
+
+  for (auto &[Slot, LI] : LSS) {
+    if (!MFI.isSpillSlotObjectIndex(Slot) || MFI.isDeadObjectIndex(Slot))
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+    if (TRI.hasVGPRs(RC))
+      StackIntervals.push_back(&LI);
+  }
+
+  sort(StackIntervals, [](const LiveInterval *A, const LiveInterval *B) {
+    // Sort heaviest intervals first to prioritize their unspilling.
+    if (A->weight() > B->weight())
+      return true;
+
+    if (A->getSize() > B->getSize())
+      return true;
+
+    // Tie break by slot number to avoid the need for a stable sort.
+    return A->reg().stackSlotIndex() < B->reg().stackSlotIndex();
+  });
+
+  // FIXME: The APIs for dealing with the LiveInterval of a frame index are
+  // cumbersome. LiveStacks owns its LiveIntervals, which refer to stack
+  // slots. We cannot use the usual LiveRegMatrix::assign and unassign on
+  // these, and must create a substitute virtual register to do so. This makes
+  // incremental updating here difficult; we need to actually perform the IR
+  // mutation to get the new vreg references in place to compute the register
+  // LiveInterval and perform an assignment that tracks the new interference
+  // correctly, and we can't simply migrate the LiveInterval we already have.
+  //
+  // To avoid walking through the entire function for each index, pre-collect
+  // all the instructions' slot references.
+
+  DenseMap<int, SmallVector<MachineInstr *, 4>> SpillSlotReferences;
+  collectSpillIndexUses(StackIntervals, SpillSlotReferences);
+
+  for (LiveInterval *LI : StackIntervals) {
+    int Slot = LI->reg().stackSlotIndex();
+    auto SpillReferences = SpillSlotReferences.find(Slot);
+    if (SpillReferences == SpillSlotReferences.end())
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+
+    LLVM_DEBUG(dbgs() << "Trying to eliminate " << printReg(Slot, &TRI)
                      << " by reassigning\n");
+
+    ArrayRef<MCPhysReg> AllocOrder = RegClassInfo.getOrder(RC);
+
+    for (MCPhysReg PhysReg : AllocOrder) {
+      if (LRM.checkInterference(*LI, PhysReg) != LiveRegMatrix::IK_Free)
+        continue;
+
+      LLVM_DEBUG(dbgs() << "Reassigning " << *LI << " to "
                        << printReg(PhysReg, &TRI) << '\n');
+
+      const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+      Register NewVReg = MRI.createVirtualRegister(RC);
+
+      for (MachineInstr *SpillMI : SpillReferences->second)
+        replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);
+
+      // TODO: We should be able to transfer the information from the stack
+      // slot's LiveInterval without recomputing from scratch with the
+      // replacement vreg uses.
+      LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval(NewVReg);
+      VRM.grow();
+      LRM.assign(NewLI, PhysReg);
+      MFI.RemoveStackObject(Slot);
+      break;
+    }
+  }
+}
+
 bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
   // This only applies on subtargets that have a configurable AGPR vs. VGPR
   // allocation.
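To make the transformation concrete, here is a rough before/after sketch of what replaceSpillWithCopyToVReg() does to a VGPR spill pair once eliminateSpillsOfReassignedVGPRs() has found a free register for a fresh virtual register. This is illustrative only, not output from the pass: spill pseudo operands are elided and the register numbers are invented.

    // Before: the value lives in a stack slot across the high-pressure region.
    //   SI_SPILL_V32_SAVE %0, %stack.0, ...        ; store to the spill slot
    //   ...
    //   %1 = SI_SPILL_V32_RESTORE %stack.0, ...    ; reload later
    //
    // After: %2 is a new vreg assigned to a free physreg via LRM.assign().
    //   %2 = COPY %0
    //   ...
    //   %1 = COPY %2
    // The now-dead frame object is then deleted with MFI.RemoveStackObject(Slot).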
@@ -417,6 +572,12 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
     MadeChange = true;
   }
 
+  // If we've successfully rewritten some MFMAs, we've alleviated some VGPR
+  // pressure. See if we can eliminate some spills now that those registers are
+  // more available.
+  if (MadeChange)
+    eliminateSpillsOfReassignedVGPRs();
+
   return MadeChange;
 }
 
@@ -440,10 +601,13 @@ class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
     AU.addRequired<LiveIntervalsWrapperPass>();
     AU.addRequired<VirtRegMapWrapperLegacy>();
     AU.addRequired<LiveRegMatrixWrapperLegacy>();
+    AU.addRequired<LiveStacksWrapperLegacy>();
 
     AU.addPreserved<LiveIntervalsWrapperPass>();
     AU.addPreserved<VirtRegMapWrapperLegacy>();
     AU.addPreserved<LiveRegMatrixWrapperLegacy>();
+    AU.addPreserved<LiveStacksWrapperLegacy>();
+
     AU.setPreservesAll();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
@@ -456,6 +620,7 @@ INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
 INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
 INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
+INITIALIZE_PASS_DEPENDENCY(LiveStacksWrapperLegacy)
 INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
                     "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
 
@@ -474,8 +639,8 @@ bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
   auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
   auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
   auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
-
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  auto &LSS = getAnalysis<LiveStacksWrapperLegacy>().getLS();
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   return Impl.run(MF);
 }
 
@@ -485,13 +650,15 @@ AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
   VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(MF);
   LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(MF);
   LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(MF);
+  LiveStacks &LSS = MFAM.getResult<LiveStacksAnalysis>(MF);
   RegisterClassInfo RegClassInfo;
   RegClassInfo.runOnMachineFunction(MF);
 
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   if (!Impl.run(MF))
     return PreservedAnalyses::all();
   auto PA = getMachineFunctionPassPreservedAnalyses();
   PA.preserveSet<CFGAnalyses>();
+  PA.preserve<LiveStacksAnalysis>();
   return PA;
 }