Skip to content

Commit 167bdc8

Browse files
authored
[WS] fold rewrite-partition-dependencies into insert-aref (#8619)
Fold the `rewrite-partition-dependencies` pass into `insert-aref`, and fold the corresponding lit test into the `insert-aref` lit tests. - `insert-aref` now handles all aref insertion that follows the producer-consumer model. - It also follows the `rewrite-partition-dependencies` logic, where `get.enter` is inserted just before the consumer, in the same block where the producer is located, instead of right after the producer as it was before. - `insert-tmem-aref` handles tmem arefs, which follow the ownership-transfer model.
1 parent 7135f4f commit 167bdc8

File tree

10 files changed

+754
-1038
lines changed

10 files changed

+754
-1038
lines changed

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ void setPartition(Operation *op, const SetVector<int> &partitionIds);
126126
void setPartitionOutputs(Operation *op,
127127
ArrayRef<SetVector<int>> partitionOutputsIds);
128128
SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
129+
SetVector<int> getPartitionIds(OpOperand *use);
129130

130131
} // namespace mlir::triton::gpu
131132

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -121,25 +121,6 @@ def TritonGPUAutomaticWarpSpecialization : Pass<"tritongpu-automatic-warp-specia
121121
];
122122
}
123123

124-
def TritonGPURewritePartitionDependencies : Pass<"tritongpu-rewrite-partition-dependencies", "mlir::ModuleOp"> {
125-
let summary = "test pass for rewriting partition dependencies";
126-
127-
let description = [{
128-
The `tritongpu-rewrite-partition-dependencies` pass analyzes the partitions
129-
assigned to a loop and their SSA dependencies. It rewrites the dependencies
130-
to be passed through shared memory, applying multi-buffering according to
131-
the assigned stages of the partitions.
132-
}];
133-
134-
let dependentDialects = [
135-
"mlir::triton::gpu::TritonGPUDialect",
136-
"mlir::scf::SCFDialect",
137-
"mlir::arith::ArithDialect",
138-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
139-
"mlir::triton::nvws::NVWSDialect"
140-
];
141-
}
142-
143124
def TritonGPUPartitionLoops : Pass<"tritongpu-partition-loops", "mlir::ModuleOp"> {
144125
let summary = "split scheduled loops into `ttg.warp_specialize`";
145126

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ add_triton_library(TritonGPUTransforms
3434
WarpSpecialization/PartitionBuilder.cpp
3535
WarpSpecialization/PartitionLoops.cpp
3636
WarpSpecialization/PartitionScheduling.cpp
37-
WarpSpecialization/RewritePartitionDependencies.cpp
3837

3938
DEPENDS
4039
TritonGPUTransformsIncGen

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ void AutomaticWarpSpecialization::runOnOperation() {
3737
pm.addPass(createTritonGPUPartitionScheduling());
3838
pm.addPass(createNVWSInsertAref());
3939
pm.addPass(createNVWSInsertTmemAref());
40-
pm.addPass(createTritonGPURewritePartitionDependencies());
4140
// `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
4241
// FIXME: Re-enable integer range analysis once it is fixed.
4342
// pm.addPass(arith::createIntRangeOptimizationsPass());

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,18 @@ SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op) {
249249
return partitionOutputsIds;
250250
}
251251

252+
SetVector<int> getPartitionIds(OpOperand *use) {
253+
auto owner = use->getOwner();
254+
if (isa<scf::YieldOp>(owner)) {
255+
return getPartitionOutputs(owner->getParentOp())[use->getOperandNumber()];
256+
} else if (scf::ForOp forOp = dyn_cast<scf::ForOp>(owner)) {
257+
int idx = use->getOperandNumber() - forOp.getNumControlOperands();
258+
return idx >= 0 ? getPartitionOutputs(owner)[idx] : *getPartitionIds(forOp);
259+
} else {
260+
return *getPartitionIds(owner);
261+
}
262+
}
263+
252264
bool hasPartition(Operation *op) { return getPartitionIds(op) != std::nullopt; }
253265

254266
} // namespace mlir::triton::gpu

0 commit comments

Comments
 (0)