Skip to content

Commit ddaa11f

Browse files
authored
Revert "[WS] Update RewritePartitionDepdencies to insert arefs" (#7611)
Reverts triton-lang/triton#7561 as it causes functional failures in our internal attention tests
1 parent 8386213 commit ddaa11f

File tree

13 files changed

+919
-542
lines changed

13 files changed

+919
-542
lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ def TritonGPURewritePartitionDependencies : Pass<"tritongpu-rewrite-partition-de
130130
"mlir::triton::gpu::TritonGPUDialect",
131131
"mlir::scf::SCFDialect",
132132
"mlir::arith::ArithDialect",
133-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
134-
"mlir::triton::nvws::NVWSDialect"
133+
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
135134
];
136135
}
137136

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "mlir/Pass/Pass.h"
44
#include "mlir/Pass/PassManager.h"
55
#include "mlir/Transforms/Passes.h"
6-
#include "third_party/nvidia/include/Dialect/NVWS/Transforms/Passes.h"
76
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
87
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
98

@@ -43,8 +42,6 @@ void AutomaticWarpSpecialization::runOnOperation() {
4342
pm.addPass(createSCCPPass());
4443
pm.addPass(createCSEPass());
4544
pm.addPass(createTritonGPUPartitionLoops());
46-
pm.addPass(createNVWSLowerAref());
47-
pm.addPass(createNVWSLowerWarpGroup());
4845
if (failed(runPipeline(pm, getOperation())))
4946
return signalPassFailure();
5047

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,57 @@ LogicalResult WarpSchedule::verify(scf::ForOp loop) const {
191191
if (failed)
192192
return failure();
193193

194+
// Within a loop iteration, the partitions must form a DAG. For example, the
195+
// following is invalid:
196+
//
197+
// scf.for %i = %lb to %ub step %step
198+
// %0 = op_a() {ttg.partition = 0}
199+
// %1 = op_b(%0) {ttg.partition = 1}
200+
// op_c(%1) {ttg.partition = 0}
201+
//
202+
PartitionGraph graph(loop, *this);
203+
for (auto it = llvm::scc_begin(graph); !it.isAtEnd(); ++it) {
204+
if (!it.hasCycle())
205+
continue;
206+
InFlightDiagnostic diag =
207+
mlir::emitWarning(loop.getLoc(), "warp schedule contains a cycle");
208+
for (auto [node, use] : *it) {
209+
assert(use && "already checked that the root partition has no ancestors");
210+
diag.attachNote(use->getOwner()->getLoc())
211+
<< "operation in partition #" << node->partition->getIndex()
212+
<< " uses value defined in partition #"
213+
<< opToPartition.at(use->get().getDefiningOp())->getIndex();
214+
}
215+
return failure();
216+
}
217+
218+
// Each partition's stage must be strictly less than all of its consumers plus
219+
// the distance.
220+
for (Partition &partition : getPartitions()) {
221+
bool failed = false;
222+
auto callback = [&](OpResult output, OpOperand &use, unsigned distance) {
223+
Operation *user = loop.getBody()->findAncestorOpInBlock(*use.getOwner());
224+
const Partition *consumer = opToPartition.at(user);
225+
if (partition.getStage() < consumer->getStage() + distance)
226+
return;
227+
InFlightDiagnostic diag =
228+
mlir::emitWarning(loop.getLoc(), "partition #")
229+
<< partition.getIndex() << " has stage " << partition.getStage()
230+
<< " but is consumed by partition #" << consumer->getIndex()
231+
<< " with stage " << consumer->getStage() << " at distance "
232+
<< distance;
233+
diag.attachNote(use.getOwner()->getLoc())
234+
<< "use of value defined in partition #" << partition.getIndex()
235+
<< " at " << distance << " iterations in the future";
236+
diag.attachNote(output.getLoc())
237+
<< "value defined here in partition #" << partition.getIndex();
238+
failed = true;
239+
};
240+
iterateUses(loop, &partition, callback);
241+
if (failed)
242+
return failure();
243+
}
244+
194245
return success();
195246
}
196247

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,4 +499,10 @@ void PartitionLoops::runOnOperation() {
499499
if (failed(partitionLoop(loop)))
500500
return signalPassFailure();
501501
}
502+
503+
OpPassManager pm;
504+
pm.addPass(mlir::triton::createNVWSLowerWarpGroup());
505+
506+
if (failed(runPipeline(pm, getOperation())))
507+
return signalPassFailure();
502508
}

0 commit comments

Comments
 (0)