Skip to content

Commit 6b70e71

Browse files
authored
[NVWS] Update RewritePartitionDependecies to insert arefs (#7645)
Take two for [PR7561](triton-lang/triton#7561) after it was reverted in [PR7611](triton-lang/triton#7561) * Integrates Remove Rewrite multiplicity ([PR7371](triton-lang/triton#7371)) * Teaches `RewritePartitionDependencies` to insert arefs. * **new**: add `aref.destroy` to emit `ttng.inval_barrier` to eliminate UB when the same smem is reused later in aref mbarrier <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [X] I am not making a trivial change, such as fixing a typo in a comment. - [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [X] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [X] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. -->
1 parent a837a04 commit 6b70e71

File tree

13 files changed

+600
-915
lines changed

13 files changed

+600
-915
lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def TritonGPURewritePartitionDependencies : Pass<"tritongpu-rewrite-partition-de
135135
"mlir::triton::gpu::TritonGPUDialect",
136136
"mlir::scf::SCFDialect",
137137
"mlir::arith::ArithDialect",
138-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
138+
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
139+
"mlir::triton::nvws::NVWSDialect"
139140
];
140141
}
141142

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/Pass/Pass.h"
44
#include "mlir/Pass/PassManager.h"
55
#include "mlir/Transforms/Passes.h"
6+
#include "third_party/nvidia/include/Dialect/NVWS/Transforms/Passes.h"
67
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
78
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
89

@@ -42,6 +43,8 @@ void AutomaticWarpSpecialization::runOnOperation() {
4243
pm.addPass(createSCCPPass());
4344
pm.addPass(createCSEPass());
4445
pm.addPass(createTritonGPUPartitionLoops());
46+
pm.addPass(createNVWSLowerAref());
47+
pm.addPass(createNVWSLowerWarpGroup());
4548
if (failed(runPipeline(pm, getOperation())))
4649
return signalPassFailure();
4750

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -191,57 +191,6 @@ LogicalResult WarpSchedule::verify(scf::ForOp loop) const {
191191
if (failed)
192192
return failure();
193193

194-
// Within a loop iteration, the partitions must form a DAG. For example, the
195-
// following is invalid:
196-
//
197-
// scf.for %i = %lb to %ub step %step
198-
// %0 = op_a() {ttg.partition = 0}
199-
// %1 = op_b(%0) {ttg.partition = 1}
200-
// op_c(%1) {ttg.partition = 0}
201-
//
202-
PartitionGraph graph(loop, *this);
203-
for (auto it = llvm::scc_begin(graph); !it.isAtEnd(); ++it) {
204-
if (!it.hasCycle())
205-
continue;
206-
InFlightDiagnostic diag =
207-
mlir::emitWarning(loop.getLoc(), "warp schedule contains a cycle");
208-
for (auto [node, use] : *it) {
209-
assert(use && "already checked that the root partition has no ancestors");
210-
diag.attachNote(use->getOwner()->getLoc())
211-
<< "operation in partition #" << node->partition->getIndex()
212-
<< " uses value defined in partition #"
213-
<< opToPartition.at(use->get().getDefiningOp())->getIndex();
214-
}
215-
return failure();
216-
}
217-
218-
// Each partition's stage must be strictly less than all of its consumers plus
219-
// the distance.
220-
for (Partition &partition : getPartitions()) {
221-
bool failed = false;
222-
auto callback = [&](OpResult output, OpOperand &use, unsigned distance) {
223-
Operation *user = loop.getBody()->findAncestorOpInBlock(*use.getOwner());
224-
const Partition *consumer = opToPartition.at(user);
225-
if (partition.getStage() < consumer->getStage() + distance)
226-
return;
227-
InFlightDiagnostic diag =
228-
mlir::emitWarning(loop.getLoc(), "partition #")
229-
<< partition.getIndex() << " has stage " << partition.getStage()
230-
<< " but is consumed by partition #" << consumer->getIndex()
231-
<< " with stage " << consumer->getStage() << " at distance "
232-
<< distance;
233-
diag.attachNote(use.getOwner()->getLoc())
234-
<< "use of value defined in partition #" << partition.getIndex()
235-
<< " at " << distance << " iterations in the future";
236-
diag.attachNote(output.getLoc())
237-
<< "value defined here in partition #" << partition.getIndex();
238-
failed = true;
239-
};
240-
iterateUses(loop, &partition, callback);
241-
if (failed)
242-
return failure();
243-
}
244-
245194
return success();
246195
}
247196

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -499,10 +499,4 @@ void PartitionLoops::runOnOperation() {
499499
if (failed(partitionLoop(loop)))
500500
return signalPassFailure();
501501
}
502-
503-
OpPassManager pm;
504-
pm.addPass(mlir::triton::createNVWSLowerWarpGroup());
505-
506-
if (failed(runPipeline(pm, getOperation())))
507-
return signalPassFailure();
508502
}

0 commit comments

Comments
 (0)