Skip to content

Commit 4ea942b

Browse files
authored
[SWP] Attempt to move all scheduling logic to a scheduling pass (triton-lang#4618)
The main purpose is to move all scheduling related logic including scheduleDeps/scheduleDistOne/scheduleRemaining to a new loopScheduling pass. The new pass will generate (stage, cluster) attributes for each operation inside a loop. For operations that are created during createAsyncOps, we manually add the (stage, cluster) based on the attributes prior to lowering. In most cases, the lowered operations should have (stage, cluster) of the original loadOp, or should have (stage, cluster) of the first use of the loadOp. This patch also gets rid of prefetchCluster, instead it prefetches one stage before the actual use setStageCluster(forOp, wait, stageForFirstUse - 1, clusterForFirstUse + 1); setStageCluster(forOp, viewLoad, stageForFirstUse - 1, clusterForFirstUse + 1); comparing to schedule.insert(wait, numStages - 2, prefetchCluster); schedule.insert(viewLoad, numStages - 2, prefetchCluster); At end of createAsyncOps, we make sure all operations inside the loop have (stage, cluster) attributes. There is no need to maintain CoarseSchedule in SWP, instead we will just use (stage, cluster) attributes.
1 parent d06ec83 commit 4ea942b

File tree

19 files changed

+922
-594
lines changed

19 files changed

+922
-594
lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,4 +179,17 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init"
179179
"mlir::triton::TritonDialect"];
180180
}
181181

182+
def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> {
183+
let summary = "Generate loop scheduling for SWP";
184+
185+
let description = "This pass sets up stages and clustering for software pipelining.";
186+
187+
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
188+
"mlir::triton::TritonDialect"];
189+
let options = [
190+
Option<"numStages", "num-stages",
191+
"int32_t", /*default*/"3",
192+
"number of pipeline stages">
193+
];
194+
}
182195
#endif

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace mlir {
88
namespace triton {
99

1010
static const char *kNumStagesAttrName = "tt.num_stages";
11+
static const char *kLoopStageAttrName = "loop.stage";
12+
static const char *kLoopClusterAttrName = "loop.cluster";
1113

1214
/// Function to mask operations during scheduling.
1315
Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred);
@@ -29,6 +31,11 @@ void addOps(scf::ForOp forOp, int stage,
2931
/// mutable.
3032
void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
3133
Value val);
34+
35+
// Return the minClusterId and maxClusterId for the given ForOp.
36+
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
37+
std::pair<int, int> getStageCluster(Operation *op);
38+
void setStageCluster(scf::ForOp &forOp, Operation *op, int stage, int cluster);
3239
} // namespace triton
3340
} // namespace mlir
3441

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,16 @@ class CoarseSchedule {
101101
createFinalSchedule(scf::ForOp forOp);
102102
void dump();
103103
bool empty() { return opToStageAndCluster.size() == 0; }
104+
void serialize(scf::ForOp &forOp);
105+
// Create a CoarseSchedule based on forOp's <stage, cluster>.
106+
void deSerialize(scf::ForOp &forOp);
104107
};
105108

109+
// Add dependencies of anchor ops to the coarse schedule. Schedule them to
110+
// the same stage and ordering cluster as the anchor op.
111+
void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule,
112+
int numStages);
113+
106114
} // namespace triton
107115
} // namespace mlir
108116
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,10 @@ bool isPureUnaryInlineAsm(Operation *op);
192192
// read the compute capability from the module attributes
193193
int getNVIDIAComputeCapability(Operation *module);
194194

195+
std::optional<mlir::triton::gpu::SharedEncodingAttr>
196+
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
197+
198+
bool loadIsMMAv3(Operation *loadOp);
195199
} // namespace mlir
196200

197201
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_triton_library(TritonGPUTransforms
33
Coalesce.cpp
44
F32DotTC.cpp
55
CombineTensorSelectAndIf.cpp
6+
LoopScheduling.cpp
67
ReduceDataDuplication.cpp
78
OptimizeAccumulatorInit.cpp
89
OptimizeDotOperands.cpp

0 commit comments

Comments
 (0)