
Conversation

kerbowa (Member) commented Nov 26, 2025

This patch introduces scaffolding for a new machine instruction
scheduling strategy optimized for machine learning workloads.

Enable the ML scheduler automatically when functions have the
"amdgpu-workload-type"="ml" attribute.

github-actions bot commented Nov 26, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@kerbowa kerbowa force-pushed the users/kerbowa/ml-sched-strategy-scaffold branch from b639aa2 to 4d99b6d Compare December 2, 2025 16:41
@kerbowa kerbowa marked this pull request as ready for review December 2, 2025 18:35
llvmbot (Member) commented Dec 2, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Austin Kerbow (kerbowa)

Changes

This patch introduces scaffolding for a new machine instruction
scheduling strategy optimized for machine learning workloads.

Enable the ML scheduler automatically when functions have the
"amdgpu-workload-type"="ml" attribute.


Patch is 26.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169616.diff

8 Files Affected:

  • (added) llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp (+153)
  • (added) llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h (+43)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp (+22)
  • (modified) llvm/lib/Target/AMDGPU/CMakeLists.txt (+1)
  • (modified) llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp (+10-10)
  • (modified) llvm/lib/Target/AMDGPU/GCNSchedStrategy.h (+11)
  • (modified) llvm/lib/Target/AMDGPU/GCNSubtarget.cpp (+8)
  • (added) llvm/test/CodeGen/AMDGPU/ml-sched-effective-stall.mir (+130)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp
new file mode 100644
index 0000000000000..8c68223c0a492
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.cpp
@@ -0,0 +1,153 @@
+//===-- AMDGPUMLSchedStrategy.cpp - ML-focused Scheduler Strategy ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// ML-focused scheduling strategy for AMDGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPUMLSchedStrategy.h"
+
+using namespace llvm;
+
+AMDGPUMLSchedStrategy::AMDGPUMLSchedStrategy(const MachineSchedContext *C)
+    : GCNSchedStrategy(C) {
+  SchedStages.push_back(GCNSchedStageID::ILPInitialSchedule);
+  SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
+  // Use more accurate GCN pressure trackers.
+  UseGCNTrackers = true;
+}
+
+void AMDGPUMLSchedStrategy::initialize(ScheduleDAGMI *DAG) {
+  // The ML scheduling strategy only schedules top-down to support new resource
+  // balancing heuristics.
+  RegionPolicy.OnlyTopDown = true;
+  RegionPolicy.OnlyBottomUp = false;
+
+  GCNSchedStrategy::initialize(DAG);
+}
+
+bool AMDGPUMLSchedStrategy::tryCandidate(SchedCandidate &Cand,
+                                         SchedCandidate &TryCand,
+                                         SchedBoundary *Zone) const {
+  // Initialize the candidate if needed.
+  if (!Cand.isValid()) {
+    TryCand.Reason = FirstValid;
+    return true;
+  }
+
+  // Bias PhysReg defs and copies toward their uses and definitions, respectively.
+  if (tryGreater(biasPhysReg(TryCand.SU, TryCand.AtTop),
+                 biasPhysReg(Cand.SU, Cand.AtTop), TryCand, Cand, PhysReg))
+    return TryCand.Reason != NoCand;
+
+  // Avoid exceeding the target's limit.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.Excess, Cand.RPDelta.Excess, TryCand, Cand,
+                  RegExcess, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  // We only compare a subset of features when comparing nodes between
+  // Top and Bottom boundary. Some properties are simply incomparable; in many
+  // other instances we should only override the other boundary if something
+  // is a clear good pick on one boundary. Skip heuristics that are more
+  // "tie-breaking" in nature.
+  bool SameBoundary = Zone != nullptr;
+  if (SameBoundary) {
+    // For loops that are acyclic path limited, aggressively schedule for
+    // latency. Within a single cycle, whenever CurrMOps > 0, allow normal
+    // heuristics to take precedence.
+    if (Rem.IsAcyclicLatencyLimited && !Zone->getCurrMOps() &&
+        tryLatency(TryCand, Cand, *Zone))
+      return TryCand.Reason != NoCand;
+
+    // Prioritize instructions that read unbuffered resources by stall cycles.
+    if (tryLess(Zone->getLatencyStallCycles(TryCand.SU),
+                Zone->getLatencyStallCycles(Cand.SU), TryCand, Cand, Stall))
+      return TryCand.Reason != NoCand;
+  }
+
+  // Keep clustered nodes together to encourage downstream peephole
+  // optimizations which may reduce resource requirements.
+  //
+  // This is a best effort to set things up for a post-RA pass. Optimizations
+  // like generating loads of multiple registers should ideally be done within
+  // the scheduler pass by combining the loads during DAG postprocessing.
+  unsigned CandZoneCluster = Cand.AtTop ? TopClusterID : BotClusterID;
+  unsigned TryCandZoneCluster = TryCand.AtTop ? TopClusterID : BotClusterID;
+  bool CandIsClusterSucc =
+      isTheSameCluster(CandZoneCluster, Cand.SU->ParentClusterIdx);
+  bool TryCandIsClusterSucc =
+      isTheSameCluster(TryCandZoneCluster, TryCand.SU->ParentClusterIdx);
+
+  if (tryGreater(TryCandIsClusterSucc, CandIsClusterSucc, TryCand, Cand,
+                 Cluster))
+    return TryCand.Reason != NoCand;
+
+  if (SameBoundary) {
+    // Weak edges are for clustering and other constraints.
+    if (tryLess(getWeakLeft(TryCand.SU, TryCand.AtTop),
+                getWeakLeft(Cand.SU, Cand.AtTop), TryCand, Cand, Weak))
+      return TryCand.Reason != NoCand;
+  }
+
+  // Avoid increasing the max pressure of the entire region.
+  if (DAG->isTrackingPressure() &&
+      tryPressure(TryCand.RPDelta.CurrentMax, Cand.RPDelta.CurrentMax, TryCand,
+                  Cand, RegMax, TRI, DAG->MF))
+    return TryCand.Reason != NoCand;
+
+  if (SameBoundary) {
+    // Avoid critical resource consumption and balance the schedule.
+    TryCand.initResourceDelta(DAG, SchedModel);
+    if (tryLess(TryCand.ResDelta.CritResources, Cand.ResDelta.CritResources,
+                TryCand, Cand, ResourceReduce))
+      return TryCand.Reason != NoCand;
+    if (tryGreater(TryCand.ResDelta.DemandedResources,
+                   Cand.ResDelta.DemandedResources, TryCand, Cand,
+                   ResourceDemand))
+      return TryCand.Reason != NoCand;
+
+    // Avoid serializing long latency dependence chains.
+    // For acyclic path limited loops, latency was already checked above.
+    if (!RegionPolicy.DisableLatencyHeuristic && TryCand.Policy.ReduceLatency &&
+        !Rem.IsAcyclicLatencyLimited && tryLatency(TryCand, Cand, *Zone))
+      return TryCand.Reason != NoCand;
+
+    // Fall through to original instruction order.
+    if ((Zone->isTop() && TryCand.SU->NodeNum < Cand.SU->NodeNum) ||
+        (!Zone->isTop() && TryCand.SU->NodeNum > Cand.SU->NodeNum)) {
+      TryCand.Reason = NodeOrder;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+AMDGPUMLPostSchedStrategy::AMDGPUMLPostSchedStrategy(
+    const MachineSchedContext *C)
+    : PostGenericScheduler(C) {}
+
+bool AMDGPUMLPostSchedStrategy::tryCandidate(SchedCandidate &Cand,
+                                             SchedCandidate &TryCand) {
+  // Initialize the candidate if needed.
+  if (!Cand.isValid()) {
+    TryCand.Reason = FirstValid;
+    return true;
+  }
+
+  // Fall through to original instruction order.
+  // This effectively only enables hazard checking for post-RA scheduling.
+  if (TryCand.SU->NodeNum < Cand.SU->NodeNum) {
+    TryCand.Reason = NodeOrder;
+    return true;
+  }
+
+  return false;
+}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h b/llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h
new file mode 100644
index 0000000000000..fd13b57a28f43
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPUMLSchedStrategy.h
@@ -0,0 +1,43 @@
+//===-- AMDGPUMLSchedStrategy.h - ML-focused Scheduler Strategy -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// ML-focused scheduling strategy for AMDGPU.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H
+#define LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H
+
+#include "GCNSchedStrategy.h"
+#include "llvm/CodeGen/MachineScheduler.h"
+
+namespace llvm {
+
+class AMDGPUMLSchedStrategy final : public GCNSchedStrategy {
+protected:
+  bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand,
+                    SchedBoundary *Zone) const override;
+
+public:
+  AMDGPUMLSchedStrategy(const MachineSchedContext *C);
+
+  void initialize(ScheduleDAGMI *DAG) override;
+};
+
+class AMDGPUMLPostSchedStrategy : public PostGenericScheduler {
+protected:
+  bool tryCandidate(SchedCandidate &Cand, SchedCandidate &TryCand) override;
+
+public:
+  AMDGPUMLPostSchedStrategy(const MachineSchedContext *C);
+};
+
+} // End namespace llvm
+
+#endif // LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H
\ No newline at end of file
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index e5a35abe6da6b..f2fd137e4dd6f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -24,6 +24,7 @@
 #include "AMDGPUIGroupLP.h"
 #include "AMDGPUISelDAGToDAG.h"
 #include "AMDGPULowerVGPREncoding.h"
+#include "AMDGPUMLSchedStrategy.h"
 #include "AMDGPUMacroFusion.h"
 #include "AMDGPUPerfHintAnalysis.h"
 #include "AMDGPUPreloadKernArgProlog.h"
@@ -636,6 +637,11 @@ static ScheduleDAGInstrs *createSIMachineScheduler(MachineSchedContext *C) {
   return new SIScheduleDAGMI(C);
 }
 
+static bool isMLWorkload(const Function &F) {
+  Attribute WorkloadAttr = F.getFnAttribute("amdgpu-workload-type");
+  return WorkloadAttr.isValid() && WorkloadAttr.getValueAsString() == "ml";
+}
+
 static ScheduleDAGInstrs *
 createGCNMaxOccupancyMachineScheduler(MachineSchedContext *C) {
   const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -659,6 +665,11 @@ createGCNMaxILPMachineScheduler(MachineSchedContext *C) {
   return DAG;
 }
 
+static ScheduleDAGInstrs *createGCNMLMachineScheduler(MachineSchedContext *C) {
+  return new GCNScheduleDAGMILive(C,
+                                  std::make_unique<AMDGPUMLSchedStrategy>(C));
+}
+
 static ScheduleDAGInstrs *
 createGCNMaxMemoryClauseMachineScheduler(MachineSchedContext *C) {
   const GCNSubtarget &ST = C->MF->getSubtarget<GCNSubtarget>();
@@ -1170,6 +1181,9 @@ GCNTargetMachine::createMachineScheduler(MachineSchedContext *C) const {
   if (ST.enableSIScheduler())
     return createSIMachineScheduler(C);
 
+  if (isMLWorkload(C->MF->getFunction()))
+    return createGCNMLMachineScheduler(C);
+
   Attribute SchedStrategyAttr =
       C->MF->getFunction().getFnAttribute("amdgpu-sched-strategy");
   StringRef SchedStrategy = SchedStrategyAttr.isValid()
@@ -1191,11 +1205,19 @@ GCNTargetMachine::createMachineScheduler(MachineSchedContext *C) const {
   if (SchedStrategy == "iterative-maxocc")
     return createIterativeGCNMaxOccupancyMachineScheduler(C);
 
+  if (SchedStrategy == "ml")
+    return createGCNMLMachineScheduler(C);
+
   return createGCNMaxOccupancyMachineScheduler(C);
 }
 
 ScheduleDAGInstrs *
 GCNTargetMachine::createPostMachineScheduler(MachineSchedContext *C) const {
+  if (isMLWorkload(C->MF->getFunction()))
+    return new GCNPostScheduleDAGMILive(
+        C, std::make_unique<AMDGPUMLPostSchedStrategy>(C),
+        /*RemoveKillFlags=*/true);
+
   ScheduleDAGMI *DAG =
       new GCNPostScheduleDAGMILive(C, std::make_unique<PostGenericScheduler>(C),
                                    /*RemoveKillFlags=*/true);
diff --git a/llvm/lib/Target/AMDGPU/CMakeLists.txt b/llvm/lib/Target/AMDGPU/CMakeLists.txt
index 4baae51e021c5..1dfa070cf84de 100644
--- a/llvm/lib/Target/AMDGPU/CMakeLists.txt
+++ b/llvm/lib/Target/AMDGPU/CMakeLists.txt
@@ -89,6 +89,7 @@ add_llvm_target(AMDGPUCodeGen
   AMDGPUMacroFusion.cpp
   AMDGPUMCInstLower.cpp
   AMDGPUMemoryUtils.cpp
+  AMDGPUMLSchedStrategy.cpp
   AMDGPUIGroupLP.cpp
   AMDGPULowerVGPREncoding.cpp
   AMDGPUMCResourceInfo.cpp
diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
index c8ce3aab3f303..b9362c41cdb7c 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp
@@ -184,7 +184,7 @@ static bool canUsePressureDiffs(const SUnit &SU) {
   return true;
 }
 
-static void getRegisterPressures(
+void GCNSchedStrategy::getRegisterPressures(
     bool AtTop, const RegPressureTracker &RPTracker, SUnit *SU,
     std::vector<unsigned> &Pressure, std::vector<unsigned> &MaxPressure,
     GCNDownwardRPTracker &DownwardTracker, GCNUpwardRPTracker &UpwardTracker,
@@ -192,7 +192,7 @@ static void getRegisterPressures(
   // getDownwardPressure() and getUpwardPressure() make temporary changes to
   // the tracker, so we need to pass those function a non-const copy.
   RegPressureTracker &TempTracker = const_cast<RegPressureTracker &>(RPTracker);
-  if (!GCNTrackers) {
+  if (!useGCNTrackers()) {
     AtTop
         ? TempTracker.getDownwardPressure(SU->getInstr(), Pressure, MaxPressure)
         : TempTracker.getUpwardPressure(SU->getInstr(), Pressure, MaxPressure);
@@ -244,7 +244,7 @@ void GCNSchedStrategy::initCandidate(SchedCandidate &Cand, SUnit *SU,
   //
   // In EXPENSIVE_CHECKS, we always query RPTracker to verify the results of
   // PressureDiffs.
-  if (AtTop || !canUsePressureDiffs(*SU) || GCNTrackers) {
+  if (AtTop || !canUsePressureDiffs(*SU) || useGCNTrackers()) {
     getRegisterPressures(AtTop, RPTracker, SU, Pressure, MaxPressure,
                          DownwardTracker, UpwardTracker, DAG, SRI);
   } else {
@@ -388,7 +388,7 @@ void GCNSchedStrategy::pickNodeFromQueue(SchedBoundary &Zone,
   unsigned VGPRPressure = 0;
   IsPending = false;
   if (DAG->isTrackingPressure()) {
-    if (!GCNTrackers) {
+    if (!useGCNTrackers()) {
       SGPRPressure = Pressure[AMDGPU::RegisterPressureSets::SReg_32];
       VGPRPressure = Pressure[AMDGPU::RegisterPressureSets::VGPR_32];
     } else {
@@ -611,7 +611,7 @@ SUnit *GCNSchedStrategy::pickNode(bool &IsTopNode) {
 }
 
 void GCNSchedStrategy::schedNode(SUnit *SU, bool IsTopNode) {
-  if (GCNTrackers) {
+  if (useGCNTrackers()) {
     MachineInstr *MI = SU->getInstr();
     IsTopNode ? (void)DownwardTracker.advance(MI, false)
               : UpwardTracker.recede(*MI);
@@ -693,7 +693,7 @@ GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
   SchedStages.push_back(GCNSchedStageID::UnclusteredHighRPReschedule);
   SchedStages.push_back(GCNSchedStageID::ClusteredLowOccupancyReschedule);
   SchedStages.push_back(GCNSchedStageID::PreRARematerialize);
-  GCNTrackers = GCNTrackers & !IsLegacyScheduler;
+  UseGCNTrackers = GCNTrackers & !IsLegacyScheduler;
 }
 
 GCNMaxILPSchedStrategy::GCNMaxILPSchedStrategy(const MachineSchedContext *C)
@@ -1118,9 +1118,10 @@ void GCNScheduleDAGMILive::finalizeSchedule() {
 void GCNScheduleDAGMILive::runSchedStages() {
   LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");
 
+  GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
   if (!Regions.empty()) {
     BBLiveInMap = getRegionLiveInMap();
-    if (GCNTrackers)
+    if (S.useGCNTrackers())
       RegionLiveOuts.buildLiveRegMap();
   }
 
@@ -1132,7 +1133,6 @@ void GCNScheduleDAGMILive::runSchedStages() {
   }
 #endif
 
-  GCNSchedStrategy &S = static_cast<GCNSchedStrategy &>(*SchedImpl);
   while (S.advanceStage()) {
     auto Stage = createSchedStage(S.getCurrentStage());
     if (!Stage->initGCNSchedStage())
@@ -1148,7 +1148,7 @@ void GCNScheduleDAGMILive::runSchedStages() {
         continue;
       }
 
-      if (GCNTrackers) {
+      if (S.useGCNTrackers()) {
         GCNDownwardRPTracker *DownwardTracker = S.getDownwardTracker();
         GCNUpwardRPTracker *UpwardTracker = S.getUpwardTracker();
         GCNRPTracker::LiveRegSet *RegionLiveIns =
@@ -1297,7 +1297,7 @@ bool PreRARematStage::initGCNSchedStage() {
 
   // Rematerialize identified instructions and update scheduler's state.
   rematerialize();
-  if (GCNTrackers)
+  if (S.useGCNTrackers())
     DAG.RegionLiveOuts.buildLiveRegMap();
   REMAT_DEBUG({
     dbgs() << "Retrying function scheduling with new min. occupancy of "
diff --git a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
index 95a931b9beb2a..367f47c3ca4ae 100644
--- a/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
+++ b/llvm/lib/Target/AMDGPU/GCNSchedStrategy.h
@@ -70,6 +70,13 @@ class GCNSchedStrategy : public GenericScheduler {
   void printCandidateDecision(const SchedCandidate &Current,
                               const SchedCandidate &Preferred);
 
+  void getRegisterPressures(bool AtTop, const RegPressureTracker &RPTracker,
+                            SUnit *SU, std::vector<unsigned> &Pressure,
+                            std::vector<unsigned> &MaxPressure,
+                            GCNDownwardRPTracker &DownwardTracker,
+                            GCNUpwardRPTracker &UpwardTracker,
+                            ScheduleDAGMI *DAG, const SIRegisterInfo *SRI);
+
   std::vector<unsigned> Pressure;
 
   std::vector<unsigned> MaxPressure;
@@ -94,6 +101,8 @@ class GCNSchedStrategy : public GenericScheduler {
   // GCN RP Tracker for botttom-up scheduling
   mutable GCNUpwardRPTracker UpwardTracker;
 
+  bool UseGCNTrackers = false;
+
 public:
   // schedule() have seen register pressure over the critical limits and had to
   // track register pressure for actual scheduling heuristics.
@@ -141,6 +150,8 @@ class GCNSchedStrategy : public GenericScheduler {
 
   bool hasNextStage() const;
 
+  bool useGCNTrackers() const { return UseGCNTrackers; }
+
   GCNSchedStageID getNextStage() const;
 
   GCNDownwardRPTracker *getDownwardTracker() { return &DownwardTracker; }
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
index c8bbcbbd76928..13a20bdee4f06 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
@@ -331,6 +331,14 @@ void GCNSubtarget::overrideSchedPolicy(MachineSchedPolicy &Policy,
   Policy.OnlyTopDown = false;
   Policy.OnlyBottomUp = false;
 
+  const Function &F = Region.RegionBegin->getMF()->getFunction();
+  Attribute WorkloadAttr = F.getFnAttribute("amdgpu-workload-type");
+  bool IsMLWorkload =
+      WorkloadAttr.isValid() && WorkloadAttr.getValueAsString() == "ml";
+  // Always schedule top-down for better balancing of HW resource usage.
+  if (IsMLWorkload)
+    Policy.OnlyTopDown = true;
+
   // Enabling ShouldTrackLaneMasks crashes the SI Machine Scheduler.
   if (!enableSIScheduler())
     Policy.ShouldTrackLaneMasks = true;
diff --git a/llvm/test/CodeGen/AMDGPU/ml-sched-effective-stall.mir b/llvm/test/CodeGen/AMDGPU/ml-sched-effective-stall.mir
new file mode 100644
index 0000000000000..bb82c7364d0ff
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/ml-sched-effective-stall.mir
@@ -0,0 +1,130 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 6
+# RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -run-pass=machine-scheduler -verify-misched %s -o - | FileCheck -check-prefix=DEFAULT %s
+# RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -run-pass=machine-scheduler -amdgpu-sched-strategy=ml -verify-misched %s -o - | FileCheck -check-prefix=ML %s
+
+# Pre-commit test for stall heuristic
+
+--- |
+  define void @with_ml_workload_attr() #0 { ret void }
+  define void @without_ml_workload_attr() #1 { ret void }
+
+  attributes #0 = { "amdgpu-workload-type"="ml" "amdgpu-waves-per-eu"="1,1" }
+  attributes #1 = { "amdgpu-waves-per-eu"="1,1" }
+...
+
+---
+name: with_ml_workload_attr
+tracksRegLiveness: true
+body: |
+  bb.0:
+    ; DEFAULT-LABEL: name: with_ml_workload_attr
+    ; DEFAULT: [[DEF:%[0-9]+]]:vreg_512_align2 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF1:%[0-9]+]]:vreg_512_align2 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF2:%[0-9]+]]:vreg_256_align2 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF3:%[0-9]+]]:vgpr_32_lo256 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF4:%[0-9]+]]:vgpr_32_lo256 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF5:%[0-9]+]]:vreg_512_align2 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF6:%[0-9]+]]:vreg_512_align2 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF7:%[0-9]+]]:vreg_256_align2 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF8:%[0-9]+]]:vgpr_32_lo256 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF9:%[0-9]+]]:vgpr_32_lo256 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[DEF10:%[0-9]+]]:vreg_64_align2 = IMPLICIT_DEF
+    ; DEFAULT-NEXT: [[GLOBAL_LOAD_DWORDX2_:%[0-9]+]]:vreg_64_align2 = GLOBAL_LOAD_DWORDX2 [[DEF10]], 0, 0, implicit $exec
+    ; DEFAULT-NEXT: early-clobber %13:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF]], [[DEF1]], 0, [[DEF2]], [[DEF3]], [[DEF4]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
+    ; DEFAULT-NEXT: [[V_PK_ADD_F32_:%[0-9]+]]:vreg_64_align2 = V_PK_ADD_F32 8, [[GLOBAL_LOAD_DWORDX2_]], 8, [[GLOBAL_LOAD_DWORDX2_]], 0, 0, 0, 0, 0, implicit $mode, implicit $exec
+    ; DEFAULT-NEXT: early-clobber %14:vreg_256_align2 = V_WMMA_SCALE_F32_16X16X128_F8F6F4_f8_f8_w32_threeaddr [[DEF5]], [[DEF6]], 0, [[DEF7]], [[DEF8]], [[DEF9]], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, implicit $exec
+    ; DEFAULT-NEXT: S_ENDP...
[truncated]

@kerbowa kerbowa requested a review from proaditya December 2, 2025 18:37

} // End namespace llvm

#endif // LLVM_LIB_TARGET_AMDGPU_AMDGPUMLSCHEDSTRATEGY_H
Review comment (Contributor):

missing empty line at EoF

bool IsMLWorkload =
WorkloadAttr.isValid() && WorkloadAttr.getValueAsString() == "ml";
// Always schedule top-down for better balancing of HW resource usage.
if (IsMLWorkload)
Review comment (Contributor):

How is this gonna interact with the front end?
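
(For illustration only, not part of this patch: until a dedicated front-end attribute exists, a front end or IR pass could attach the string directly through the existing LLVM API. The helper name below is hypothetical.)

  // Sketch: tag a function as an ML workload before codegen.
  #include "llvm/IR/Function.h"

  void markAsMLWorkload(llvm::Function &F) {
    // Must match the string checked by isMLWorkload() in this patch.
    F.addFnAttr("amdgpu-workload-type", "ml");
  }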

@shiltian shiltian requested review from arsenm and perlfu December 2, 2025 19:10
@@ -636,6 +637,11 @@ static ScheduleDAGInstrs *createSIMachineScheduler(MachineSchedContext *C) {
return new SIScheduleDAGMI(C);
}

static bool isMLWorkload(const Function &F) {
Attribute WorkloadAttr = F.getFnAttribute("amdgpu-workload-type");
Review comment (Contributor):

Just use the existing scheduler override attribute?
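
(For context, a sketch of the existing override path this patch also wires up; the function name is illustrative. The pre-RA strategy can already be selected per function via "amdgpu-sched-strategy", which the patch extends to accept "ml".)

  ; Selects the ML pre-RA strategy through the existing override attribute.
  define amdgpu_kernel void @example_kernel() #0 {
    ret void
  }

  attributes #0 = { "amdgpu-sched-strategy"="ml" }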
