Skip to content

Conversation

@tlinthic
Copy link

@tlinthic tlinthic commented Dec 2, 2025

This pull request is an update of Jeff Byrne's PR. Additionally, all unresolved comments from the original PR have been addressed.

@github-actions
Copy link

github-actions bot commented Dec 2, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Dec 2, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Tony Linthicum (tlinthic)

Changes

This pull request is an update of Jeff Byrne's PR. Additionally, all unresolved comments from the original PR have been addressed.


Patch is 390.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170335.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/MachineInstrBuilder.h (+9)
  • (modified) llvm/lib/Target/AMDGPU/GCNRegPressure.h (+31)
  • (modified) llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp (+637)
  • (modified) llvm/lib/Target/AMDGPU/GCNSchedStrategy.h (+63-5)
  • (added) llvm/test/CodeGen/AMDGPU/sched_mfma_rewrite_copies.mir (+5591)
  • (added) llvm/test/CodeGen/AMDGPU/sched_mfma_rewrite_cost.mir (+524)
diff --git a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h
index caeb430d6fd1c..8c16b06bce458 100644
--- a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h
+++ b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h
@@ -375,6 +375,15 @@ class MachineInstrBuilder {
     return *this;
   }
 
+  /// Inserts the newly-built instruction after the given position in the
+  /// given MachineBasicBlock.
+  const MachineInstrBuilder &insertAfter(MachineInstr *MInstr) const {
+    MachineBasicBlock *MBB = MInstr->getParent();
+    MachineBasicBlock::iterator I = MInstr->getIterator();
+    MBB->insertAfter(I, MI);
+    return *this;
+  }
+
   bool constrainAllUses(const TargetInstrInfo &TII,
                         const TargetRegisterInfo &TRI,
                         const RegisterBankInfo &RBI) const {
diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.h b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
index f9d3ce039092e..d13d1ddd9c0eb 100644
--- a/llvm/lib/Target/AMDGPU/GCNRegPressure.h
+++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.h
@@ -102,6 +102,37 @@ struct GCNRegPressure {
                                                 DynamicVGPRBlockSize));
   }
 
+  unsigned getVGPRSpills(MachineFunction &MF) {
+    const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
+    if (!ST.hasGFX90AInsts())
+      return 0;
+
+    auto MaxVectorRegs = ST.getMaxNumVectorRegs(MF.getFunction());
+    unsigned ArchVGPRThreshold = MaxVectorRegs.first;
+    unsigned AGPRThreshold = MaxVectorRegs.second;
+
+    unsigned ArchPressure = getArchVGPRNum();
+    unsigned AGPRPressure = getAGPRNum();
+
+    unsigned ArchSpill = ArchPressure > ArchVGPRThreshold
+                             ? (ArchPressure - ArchVGPRThreshold)
+                             : 0;
+    unsigned AGPRSpill =
+        AGPRPressure > AGPRThreshold ? (AGPRPressure - AGPRThreshold) : 0;
+
+    unsigned UnifiedSpill = 0;
+
+    if (ST.hasGFX90AInsts()) {
+      unsigned CombinedThreshold = ST.getMaxNumVGPRs(MF);
+      unsigned UnifiedPressure = getVGPRNum(true);
+      UnifiedSpill = UnifiedPressure > CombinedThreshold
+                         ? (UnifiedPressure - CombinedThreshold)
+                         : 0;
+    }
+
+    return std::max(UnifiedSpill, (ArchSpill + AGPRSpill));
+  }
+
   void inc(unsigned Reg,
            LaneBitmask PrevMask,
            LaneBitmask NewMask,
diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
index c8ce3aab3f303..0773789c0ace2 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
@@ -30,6 +30,7 @@
 #include "Utils/AMDGPUBaseInfo.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/CodeGen/CalcSpillWeights.h"
+#include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/RegisterClassInfo.h"
 #include "llvm/MC/LaneBitmask.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -690,6 +691,7 @@ GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
     const MachineSchedContext *C, bool IsLegacyScheduler)
     : GCNSchedStrategy(C) {
   SchedStages.push_back(GCNSchedStageID::OccInitialSchedule);
+  SchedStages.push_back(GCNSchedStageID::RewriteSchedule);
   SchedStages.push_back(GCNSchedStageID::UnclusteredHighRPReschedule);
   SchedStages.push_back(GCNSchedStageID::ClusteredLowOccupancyReschedule);
   SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
@@ -946,6 +948,8 @@ GCNScheduleDAGMILive::createSchedStage(GCNSchedStageID SchedStageID) {
   switch (SchedStageID) {
   case GCNSchedStageID::OccInitialSchedule:
     return std::make_unique<OccInitialScheduleStage>(SchedStageID, *this);
+  case GCNSchedStageID::RewriteSchedule:
+    return std::make_unique<RewriteScheduleStage>(SchedStageID, *this);
   case GCNSchedStageID::UnclusteredHighRPReschedule:
     return std::make_unique<UnclusteredHighRPStage>(SchedStageID, *this);
   case GCNSchedStageID::ClusteredLowOccupancyReschedule:
@@ -1183,6 +1187,9 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const GCNSchedStageID &StageID) {
   case GCNSchedStageID::OccInitialSchedule:
     OS << "Max Occupancy Initial Schedule";
     break;
+  case GCNSchedStageID::RewriteSchedule:
+    OS << "Instruction Rewriting Reschedule";
+    break;
   case GCNSchedStageID::UnclusteredHighRPReschedule:
     OS << "Unclustered High Register Pressure Reschedule";
     break;
@@ -1216,6 +1223,110 @@ bool GCNSchedStage::initGCNSchedStage() {
   return true;
 }
 
+void RewriteScheduleStage::findReachingDefs(
+    MachineOperand &UseMO, LiveIntervals *LIS,
+    SmallVectorImpl<SlotIndex> &DefIdxs) {
+  assert(UseMO.isReg());
+  MachineInstr *UseMI = UseMO.getParent();
+  LiveInterval &UseLI = LIS->getInterval(UseMO.getReg());
+  VNInfo *VNI = UseLI.getVNInfoAt(LIS->getInstructionIndex(*UseMI));
+
+  SlotIndex DefMBBStart = LIS->getMBBStartIdx(LIS->getMBBFromIndex(VNI->def));
+
+  // If the def is in the block, then it must be the only reaching def.
+  if (DefMBBStart != VNI->def) {
+    DefIdxs.push_back(VNI->def);
+    return;
+  }
+
+  SmallPtrSet<MachineBasicBlock *, 8> Visited;
+  SmallVector<MachineBasicBlock *, 8> Worklist;
+
+  Visited.insert(UseMI->getParent());
+
+  // Mark the predecessor blocks for traversal
+  for (auto *PredMBB : UseMI->getParent()->predecessors()) {
+    Worklist.push_back(PredMBB);
+    Visited.insert(PredMBB);
+  }
+
+  while (!Worklist.empty()) {
+    MachineBasicBlock *CurrMBB = Worklist.pop_back_val();
+
+    SlotIndex CurrMBBEnd = LIS->getMBBEndIdx(CurrMBB);
+    VNInfo *VNI = UseLI.getVNInfoAt(CurrMBBEnd.getPrevSlot());
+
+    MachineBasicBlock *DefMBB = LIS->getMBBFromIndex(VNI->def);
+    SlotIndex DefMBBStart = LIS->getMBBStartIdx(DefMBB);
+
+    // If there is a def in this block, then add it to the list. This is the
+    // reaching def of this path.
+    if (DefMBBStart != VNI->def) {
+      DefIdxs.push_back(VNI->def);
+      continue;
+    }
+
+    for (auto *PredMBB : DefMBB->predecessors()) {
+      if (Visited.insert(PredMBB).second)
+        Worklist.push_back(PredMBB);
+    }
+  }
+}
+
+void RewriteScheduleStage::findReachingUses(
+    MachineInstr *DefMI, LiveIntervals *LIS,
+    SmallVectorImpl<MachineOperand *> &ReachingUses) {
+  SlotIndex DefIdx = LIS->getInstructionIndex(*DefMI);
+  for (auto &UseMO :
+       DAG.MRI.use_nodbg_operands(DefMI->getOperand(0).getReg())) {
+    SmallVector<SlotIndex, 8> ReachingDefIndexes;
+    findReachingDefs(UseMO, LIS, ReachingDefIndexes);
+
+    // If we find a use that contains this DefMI in its reachingDefs, then it is
+    // a reaching use.
+    if (any_of(ReachingDefIndexes, [DefIdx](SlotIndex RDIdx) {
+          return SlotIndex::isSameInstr(RDIdx, DefIdx);
+        }))
+      ReachingUses.push_back(&UseMO);
+  }
+}
+
+bool RewriteScheduleStage::initGCNSchedStage() {
+  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
+  if (!ST.hasGFX90AInsts() || MFI.getMinWavesPerEU() > 1)
+    return false;
+
+  RegionsWithExcessArchVGPR.resize(DAG.Regions.size());
+  RegionsWithExcessArchVGPR.reset();
+  for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+    auto PressureBefore = DAG.Pressure[Region];
+    if (PressureBefore.getArchVGPRNum() > ST.getAddressableNumArchVGPRs())
+      RegionsWithExcessArchVGPR[Region] = true;
+  }
+
+  if (RegionsWithExcessArchVGPR.none())
+    return false;
+
+  TII = ST.getInstrInfo();
+  SRI = ST.getRegisterInfo();
+
+  std::vector<std::pair<MachineInstr *, unsigned>> RewriteCands;
+  DenseMap<MachineBasicBlock *, std::set<Register>> CopyForUse;
+  SmallPtrSet<MachineInstr *, 8> CopyForDef;
+
+  if (!initHeuristics(RewriteCands, CopyForUse, CopyForDef))
+    return false;
+
+  int64_t Cost = getRewriteCost(RewriteCands, CopyForUse, CopyForDef);
+
+  // If we haven't found the beneficial conditions, prefer the VGPR form which
+  // may result in less cross RC copies.
+  if (Cost > 0)
+    return false;
+
+  return rewrite(RewriteCands);
+}
+
 bool UnclusteredHighRPStage::initGCNSchedStage() {
   if (DisableUnclusterHighRP)
     return false;
@@ -1837,6 +1948,532 @@ void GCNSchedStage::revertScheduling() {
   DAG.Regions[RegionIdx] = std::pair(DAG.RegionBegin, DAG.RegionEnd);
 }
 
+bool RewriteScheduleStage::isRewriteCandidate(MachineInstr *MI) const {
+
+  if (!static_cast<const SIInstrInfo *>(DAG.TII)->isMAI(*MI))
+    return false;
+  return AMDGPU::getMFMASrcCVDstAGPROp(MI->getOpcode()) != -1;
+}
+
+bool RewriteScheduleStage::initHeuristics(
+    std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+    DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+    SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+  // Prepare for the heuristics
+  for (auto &MBB : MF) {
+    for (auto &MI : MBB) {
+      if (!isRewriteCandidate(&MI))
+        continue;
+
+      int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode());
+      assert(ReplacementOp != -1);
+
+      RewriteCands.push_back({&MI, MI.getOpcode()});
+      MI.setDesc(TII->get(ReplacementOp));
+
+      MachineOperand *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2);
+      if (Src2->isReg()) {
+        SmallVector<SlotIndex, 8> Src2ReachingDefs;
+        findReachingDefs(*Src2, DAG.LIS, Src2ReachingDefs);
+
+        // For any definition of the src2 register which is non-MFMA, we
+        // insert a copy.
+        for (SlotIndex RDIdx : Src2ReachingDefs) {
+          MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIdx);
+          if (!TII->isMAI(*RD))
+            CopyForDef.insert(RD);
+        }
+      }
+
+        MachineOperand &Dst = MI.getOperand(0);
+        SmallVector<MachineOperand *, 8> DstReachingUses;
+
+        findReachingUses(&MI, DAG.LIS, DstReachingUses);
+
+        for (MachineOperand *RUOp : DstReachingUses) {
+          if (TII->isMAI(*RUOp->getParent()))
+            continue;
+
+          // For any user of the result of the MFMA which is not an MFMA, we
+          // insert a copy. For a given register, we will only insert one copy
+          // per user block.
+          CopyForUse[RUOp->getParent()->getParent()].insert(RUOp->getReg());
+
+          SmallVector<SlotIndex, 8> DstUsesReachingDefs;
+          findReachingDefs(*RUOp, DAG.LIS, DstUsesReachingDefs);
+
+          for (auto RDIndex : DstUsesReachingDefs) {
+            MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIndex);
+            if (TII->isMAI(*RD))
+              continue;
+
+            // For any definition of the user of the MFMA which is not an MFMA,
+            // we insert a copy. We do this to transform all the reaching defs
+            // of this use to AGPR. By doing this, we can insert a copy from
+            // AGPR to VGPR at the user rather than after the MFMA.
+            CopyForDef.insert(RD);
+          }
+        }
+
+        // Do the rewrite to allow for updated RP calculation.
+        const TargetRegisterClass *VGPRRC = DAG.MRI.getRegClass(Dst.getReg());
+        const TargetRegisterClass *AGPRRC = SRI->getEquivalentAGPRClass(VGPRRC);
+        DAG.MRI.setRegClass(Dst.getReg(), AGPRRC);
+        if (Src2->isReg())
+          DAG.MRI.setRegClass(Src2->getReg(), AGPRRC);
+    }
+  }
+
+  return true;
+}
+
+int64_t RewriteScheduleStage::getRewriteCost(
+    const std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands,
+    const DenseMap<MachineBasicBlock *, std::set<Register>> &CopyForUse,
+    const SmallPtrSetImpl<MachineInstr *> &CopyForDef) {
+  MachineBranchProbabilityInfo MBPI;
+  MachineBlockFrequencyInfo MBFI;
+
+  MBFI.calculate(MF, MBPI, *DAG.MLI);
+  int64_t BestSpillCost = 0;
+  int64_t Cost = 0;
+
+  uint64_t EntryFreq = MBFI.getEntryFreq().getFrequency();
+
+  for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+    if (!RegionsWithExcessArchVGPR[Region])
+      continue;
+
+    GCNRegPressure &PressureBefore = DAG.Pressure[Region];
+    unsigned SpillCostBefore = PressureBefore.getVGPRSpills(MF);
+
+    // For the cases we care about (i.e. ArchVGPR usage is greater than the
+    // addressable limit), rewriting alone should bring pressure to manageable
+    // level. If we find any such region, then the rewrite is potentially
+    // beneficial.
+    GCNRegPressure PressureAfter = DAG.getRealRegPressure(Region);
+    unsigned SpillCostAfter = PressureAfter.getVGPRSpills(MF);
+
+    uint64_t BlockFreq =
+        MBFI.getBlockFreq(DAG.Regions[Region].first->getParent())
+            .getFrequency();
+
+    bool RelativeFreqIsDenom = EntryFreq > BlockFreq;
+    uint64_t RelativeFreq = EntryFreq && BlockFreq
+                                ? (RelativeFreqIsDenom ? EntryFreq / BlockFreq
+                                                       : BlockFreq / EntryFreq)
+                                : 1;
+
+    // This assumes perfect spilling / splitting -- using one spill / copy
+    // instruction and one restoreFrom / copy for each excess register,
+    int64_t SpillCost = ((int)SpillCostAfter - (int)SpillCostBefore) * 2;
+
+    // Also account for the block frequency.
+    if (RelativeFreqIsDenom)
+      SpillCost /= (int64_t)RelativeFreq;
+    else
+      SpillCost *= (int64_t)RelativeFreq;
+
+    // If we have increased spilling in any block, just bail.
+    if (SpillCost > 0)
+      return SpillCost;
+
+    if (SpillCost < BestSpillCost)
+      BestSpillCost = SpillCost;
+  }
+
+  // Set the cost to the largest decrease in spill cost in order to not double
+  // count spill reductions.
+  Cost = BestSpillCost;
+
+  assert(Cost <= 0);
+
+  unsigned CopyCost = 0;
+
+  // For each CopyForDef, increase the cost by the register size while
+  // accounting for block frequency.
+  for (auto *DefMI : CopyForDef) {
+    auto DefReg = DefMI->getOperand(0).getReg();
+    uint64_t DefFreq =
+        EntryFreq
+            ? MBFI.getBlockFreq(DefMI->getParent()).getFrequency() / EntryFreq
+            : 1;
+
+    unsigned RegSize = DAG.TRI->getRegSizeInBits(*DAG.MRI.getRegClass(DefReg));
+    unsigned NumRegs = std::max(RegSize / 32, (unsigned)1);
+    CopyCost += NumRegs * DefFreq;
+  }
+
+  // Account for CopyForUse copies in each block that the register is used.
+  for (auto &[UseBlock, UseRegs] : CopyForUse) {
+    uint64_t UseFreq =
+        EntryFreq ? MBFI.getBlockFreq(UseBlock).getFrequency() / EntryFreq : 1;
+
+    for (auto UseReg : UseRegs) {
+      unsigned RegSize =
+          DAG.TRI->getRegSizeInBits(*DAG.MRI.getRegClass(UseReg));
+      unsigned NumRegs = std::max(RegSize / 32, (unsigned)1);
+      CopyCost += NumRegs * UseFreq;
+    }
+  }
+
+  Cost += CopyCost;
+
+  // Reset to the vgpr form. We must do rewriting after copy-insertion, as some
+  // defs of the register may require VGPR.
+  for (auto &[MI, OriginalOpcode] : RewriteCands) {
+    assert(TII->isMAI(*MI));
+    const TargetRegisterClass *AGPRRC =
+        DAG.MRI.getRegClass(MI->getOperand(0).getReg());
+    const TargetRegisterClass *VGPRRC = SRI->getEquivalentVGPRClass(AGPRRC);
+
+    MachineOperand *Src2 = TII->getNamedOperand(*MI, AMDGPU::OpName::src2);
+    assert(Src2);
+
+    if (Src2->isReg())
+      DAG.MRI.setRegClass(Src2->getReg(), VGPRRC);
+    DAG.MRI.setRegClass(MI->getOperand(0).getReg(), VGPRRC);
+    MI->setDesc(TII->get(OriginalOpcode));
+  }
+
+  return Cost;
+}
+
+bool RewriteScheduleStage::rewrite(
+    const std::vector<std::pair<MachineInstr *, unsigned>> &RewriteCands) {
+  DenseMap<MachineInstr *, unsigned> FirstMIToRegion;
+  DenseMap<MachineInstr *, unsigned> LastMIToRegion;
+
+  for (unsigned Region = 0; Region < DAG.Regions.size(); Region++) {
+    auto Entry = DAG.Regions[Region];
+    if (Entry.first == Entry.second)
+      continue;
+
+    FirstMIToRegion[&*Entry.first] = Region;
+    if (Entry.second != Entry.first->getParent()->end())
+      LastMIToRegion[&*Entry.second] = Region;
+  }
+
+  // Rewrite the MFMAs to AGPR, and insert any copies as needed.
+  // The general assumption of the algorithm (and the previous cost calculation)
+  // is that it is better to insert the copies in the MBB of the def of the src2
+  // operands, and in the MBB of the user of the dest operands. This is based on
+  // the assumption that the MFMAs are likely to appear in loop bodies, while
+  // the src2 and dest operands are live-in / live-out of the loop. Due to this
+  // design, the algorithm for finding copy insertion points is more
+  // complicated.
+  //
+  // There are three main cases to handle: 1. the reaching defs of the src2
+  // operands, 2. the reaching uses of the dst operands, and 3. the reaching
+  // defs of the reaching uses of the dst operand.
+  //
+  // In the first case, we simply insert copies after each of the reaching
+  // definitions. In the second case, we collect all the uses of a given dest
+  // and organize them by MBB. Then, we insert 1 copy for each MBB before the
+  // earliest use. Since the use may have multiple reaching defs, and since we
+  // want to replace the register it is using with the result of the copy, we
+  // must handle case 3. In the third case, we simply insert a copy after each
+  // of the reaching defs to connect to the copy of the reaching uses of the dst
+  // reg. This allows us to avoid inserting copies next to the MFMAs.
+  //
+  // While inserting the copies, we maintain a map of operands which will use
+  // different regs (i.e. the result of the copies). For example, a case 1 src2
+  // operand will use the register result of the copies after the reaching defs,
+  // as opposed to the original register. Now that we have completed our copy
+  // analysis and placement, we can bulk update the registers. We do this
+  // separately as to avoid complicating the reachingDef and reachingUse
+  // queries.
+  //
+  // While inserting the copies, we also maintain a list or registers which we
+  // will want to reclassify as AGPR. After doing the copy insertion and the
+  // register replacement, we can finally do the reclassification. This uses the
+  // redef map, as the registers we are interested in reclassifying may be
+  // replaced by the result of a copy. We must do this after the copy analysis
+  // and placement as we must have an accurate redef map -- otherwise we may end
+  // up creating illegal instructions.
+
+  // The original registers of the MFMA that need to be reclassified as AGPR.
+  std::set<Register> RewriteRegs;
+  // The map of an original register in the MFMA to a new register (result of a
+  // copy) that it should be replaced with.
+  DenseMap<Register, Register> RedefMap;
+  // The map of the original MFMA registers to the relevant MFMA operands.
+  DenseMap<Register, std::set<MachineOperand *>> ReplaceMap;
+  // The map of reaching defs for a given register -- to avoid duplicate copies.
+  DenseMap<Register, SmallPtrSet<MachineInstr *, 8>> ReachingDefCopyMap;
+  // The map of reaching uses for a given register by basic block -- to avoid
+  // duplicate copies and to calculate per MBB insert pts.
+  DenseMap<unsigned, DenseMap<Register, SmallPtrSet<MachineOperand *, 8>>>
+      ReachingUseTracker;
+
+  for (auto &[MI, OriginalOpcode] : RewriteCands) {
+
+    int ReplacementOp = AMDGPU::getMFMASrcCVDstAGPROp(MI->getOpcode());
+    if (ReplacementOp == -1)
+      continue;
+    MI->setDesc(TII->get(ReplacementOp));
+
+    // Case 1: insert copies for the reaching defs of the Src2Reg.
+    MachineOperand *Src2 = TII->getNamedOperand(*MI, AMDGPU::OpName::src2);
+
+    if (Src2->isReg()) {
+      Register Src2Reg = Src2->getReg();
+      if (!Src2Reg.isVirtual())
+        return false;
+
+      Register MappedReg = Src2->getReg();
+      SmallVector<SlotIndex, 8> Src2ReachingDefs;
+      findReachingDefs(*Src2, DAG.LIS, Src2ReachingDefs);
+      SmallVector<MachineInstr *, 8> Src2DefsReplace;
+
+      for (auto RDIndex : Src2ReachingDefs) {
+        MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIndex);
+        if (TII->isMAI(*RD))
+          continue;
+
+        // If there is a non mai reaching def, then we need a copy.
+        if (find(Src2DefsReplace, RD) == Src2DefsReplace.end())
+          Src2DefsReplace.push_back(RD);
+      }
+
+      if (!Src2DefsReplace.empty()) {
+        if (RedefMap.contains(Src2Reg)) {
+          MappedReg = RedefMap[Src2Reg];
+        } else {
+          assert(!ReachingDefCopyMap.contains(Src2Reg));
+          const TargetRegisterClass *Src2RC = DAG.MRI.getRegClass(Src2Reg);
+          const TargetRegisterClass *V...
[truncated]

Copy link
Contributor

@shiltian shiltian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just fly by with some format comments

}

bool RewriteScheduleStage::isRewriteCandidate(MachineInstr *MI) const {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicating the core rewrite logic from AMDGPURewriteAGPRCopyMFMA is unfortunate. This should eventually cover more cases, like #168983

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting that there should be a follow on patch that supports the case referenced above?

if (Src2->isReg()) {
Register Src2Reg = Src2->getReg();
if (!Src2Reg.isVirtual())
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The instruction is already broken above by the setDesc if this fails

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setDesc removed.


Cost += CopyCost;

// Reset to the vgpr form. We must do rewriting after copy-insertion, as some
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary? This should be a one way process from VGPR to AGPR form?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed the code and couldn't find a need for it. I've removed it and it still seems to work as before. More extensive testing should be done before this is merged, but that was my plan anyway.

if (Src2->isReg())
DAG.MRI.setRegClass(Src2->getReg(), VGPRRC);
DAG.MRI.setRegClass(MI->getOperand(0).getReg(), VGPRRC);
MI->setDesc(TII->get(OriginalOpcode));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you avoid making the modification in the first place instead of roll back?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. See above.

Comment on lines 2031 to 2034
MachineBranchProbabilityInfo MBPI;
MachineBlockFrequencyInfo MBFI;

MBFI.calculate(MF, MBPI, *DAG.MLI);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should come from the pass manager instead of freshly computing them here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MachineScheduler pass does not appear to require MachineBlockFrequencyInfo. We would either have to make it do so, which seems like over kill since it's only for AMDGPU, or do something in AMDGPU's pass configuration to ensure that the analysis is available. What are you thinking here? Am I missing something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming these are already computed and available, so requiring them for the machine scheduler shouldn't cost anything. I would just do that and see if it actually changes the pass pipeline.

I'm also iffy on this being a scheduling stage instead of a standalone pass

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the dependency and the plumbing to get the analysis back to our code and removed the explicit calculation there. I see an additional MachineBlockFrequencyInfo pass when I do that so there apparently isn't a preserved analysis at the time we run the Pre-RA scheduler.

You say you are "iffy" about it being a scheduler stage vs a standalone pass. Personally, I can see arguments for either approach. I'd prefer to keep it as is, but if you feel strongly that it should be a standalone pass then we can go that route.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not alive at this point, it's likely retained after? I don't think any of these mid-regalloc passes are changing control flow anywhere, so it's probably quite easy to mark this as preserved in whatever is invalidating it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason I can't respond to your comment below regarding making the analysis preserved where it is being invalidated. I've made that change. I had to change 3 passes: PHI elimination, unreachable block elimination and SI lower control flow. Two of those passes do touch the CFG, but not in a way that I think would invalidate frequency data. They already preserved the dominator tree analysis, so this seems safe. What do you think?

@tlinthic tlinthic force-pushed the RewriteScheduleStage branch from 70d0a0a to 5745907 Compare December 5, 2025 17:44
void RewriteScheduleStage::findReachingDefs(
MachineOperand &UseMO, LiveIntervals *LIS,
SmallVectorImpl<SlotIndex> &DefIdxs) {
assert(UseMO.isReg());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert(UseMO.isReg());

getReg will just fail the same assert anyway

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 1240 to 1243
SmallPtrSet<MachineBasicBlock *, 8> Visited;
SmallVector<MachineBasicBlock *, 8> Worklist;

Visited.insert(UseMI->getParent());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SmallPtrSet<MachineBasicBlock *, 8> Visited;
SmallVector<MachineBasicBlock *, 8> Worklist;
Visited.insert(UseMI->getParent());
SmallPtrSet<MachineBasicBlock *, 8> Visited = {UseMI->getParent()};
SmallVector<MachineBasicBlock *, 8> Worklist;

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines +1247 to +1248
Worklist.push_back(PredMBB);
Visited.insert(PredMBB);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe should be using SmallSetVector

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, missed this one yesterday. For both or just Visited?

MachineInstr *DefMI, LiveIntervals *LIS,
SmallVectorImpl<MachineOperand *> &ReachingUses) {
SlotIndex DefIdx = LIS->getInstructionIndex(*DefMI);
for (auto &UseMO :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto &UseMO :
for (MachineOperand &UseMO :

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

switch (SchedStageID) {
case GCNSchedStageID::OccInitialSchedule:
return std::make_unique<OccInitialScheduleStage>(SchedStageID, *this);
case GCNSchedStageID::RewriteSchedule:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is too general for what this is doing

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 2031 to 2034
MachineBranchProbabilityInfo MBPI;
MachineBlockFrequencyInfo MBFI;

MBFI.calculate(MF, MBPI, *DAG.MLI);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming these are already computed and available, so requiring them for the machine scheduler shouldn't cost anything. I would just do that and see if it actually changes the pass pipeline.

I'm also iffy on this being a scheduling stage instead of a standalone pass

@tlinthic tlinthic force-pushed the RewriteScheduleStage branch from 5745907 to d041e80 Compare December 8, 2025 21:00
// insert a copy.
for (SlotIndex RDIdx : Src2ReachingDefs) {
MachineInstr *RD = DAG.LIS->getInstructionFromIndex(RDIdx);
if (!TII->isMAI(*RD))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be checking isRewriteCandidate instead of hardcoding the MAI case?

Copy link
Author

@tlinthic tlinthic Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It breaks things with that change (rewrites don't occur). Your suggestion makes sense to me, so I'm not sure why. I'm going to leave this discussion open and circle back to it when I make the change to avoid changing opcodes and then reverting that if a rewrite is not performed. That should be in the next couple of days.

EDIT: After removing the early setDesc and subsequent reset, this change now works and is included with the last update.

Comment on lines 1960 to 1961
for (auto &MBB : MF) {
for (auto &MI : MBB) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto &MBB : MF) {
for (auto &MI : MBB) {
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 2031 to 2034
MachineBranchProbabilityInfo MBPI;
MachineBlockFrequencyInfo MBFI;

MBFI.calculate(MF, MBPI, *DAG.MLI);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not alive at this point, it's likely retained after? I don't think any of these mid-regalloc passes are changing control flow anywhere, so it's probably quite easy to mark this as preserved in whatever is invalidating it

Tony Linthicum added 4 commits December 9, 2025 08:49
…I elimination,

unreachable block elimination and SI lower control flow (AMDGPU) passes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants