Skip to content

Commit 33462c8

Browse files
authored
[AMD] Use ttg.mask in AMD StreamPipeliner (#7620)
This PR prepares the infrastructure for handling masked operations in AMD backend, by moving MaskOp handling functions into the shared pipeline utilities in order to utilize them.
1 parent 4b1e177 commit 33462c8

File tree

4 files changed

+99
-68
lines changed

4 files changed

+99
-68
lines changed

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ bool isOuterLoop(scf::ForOp forOp);
6868
/// Function to mask operations during scheduling.
6969
Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred);
7070

71+
/// Wrap the operation into a MaskOp using the provided predicate, enabling high
72+
/// level predication abstraction during pipelining.
73+
Operation *wrapInMaskOp(RewriterBase &rewriter, Operation *op, Value pred);
74+
75+
// Utilize high level predication abstraction to perform optimizations before
76+
// lowering to predicated operations
77+
void resolveMaskOp(ModuleOp moduleOp,
78+
DenseSet<triton::gpu::MaskOp> &peeledMaskOps);
79+
7180
// Return true if the given ForOp has the attribute
7281
// `tt.disallow_acc_multi_buffer` set to true.
7382
bool getDisallowAccMultiBuffer(scf::ForOp forOp);

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
#include "mlir/Analysis/TopologicalSortUtils.h"
33
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
44
#include "mlir/Dialect/Tensor/IR/Tensor.h"
5+
#include "mlir/Dialect/UB/IR/UBOps.h"
56
#include "mlir/IR/ImplicitLocOpBuilder.h"
67
#include "mlir/IR/TypeUtilities.h"
78
#include "mlir/Interfaces/SideEffectInterfaces.h"
89
#include "mlir/Support/LLVM.h"
10+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
911
#include "triton/Analysis/AxisInfo.h"
1012
#include "triton/Dialect/Triton/IR/Utility.h"
1113
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -279,6 +281,69 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
279281
return op;
280282
}
281283

284+
Operation *mlir::triton::wrapInMaskOp(RewriterBase &rewriter, Operation *op,
285+
Value pred) {
286+
auto mask =
287+
rewriter.create<ttg::MaskOp>(op->getLoc(), op->getResultTypes(), pred);
288+
rewriter.createBlock(&mask->getRegion(0));
289+
rewriter.setInsertionPointToStart(&mask->getRegion(0).front());
290+
auto newOp = rewriter.clone(*op);
291+
rewriter.create<ttg::MaskReturnOp>(op->getLoc(), newOp->getResults());
292+
op->replaceAllUsesWith(mask->getResults());
293+
rewriter.eraseOp(op);
294+
return mask;
295+
}
296+
297+
void mlir::triton::resolveMaskOp(ModuleOp moduleOp,
298+
DenseSet<ttg::MaskOp> &peeledMaskOps) {
299+
IRRewriter rewriter(moduleOp);
300+
301+
// Canonicalize the IR to simplify the arithmetic ops defining the mask
302+
auto arithDialect =
303+
moduleOp.getContext()->getLoadedDialect<arith::ArithDialect>();
304+
RewritePatternSet patterns(moduleOp.getContext());
305+
arithDialect->getCanonicalizationPatterns(patterns);
306+
if (mlir::applyPatternsGreedily(moduleOp, std::move(patterns)).failed())
307+
return llvm::report_fatal_error("Failed to canonicalize the IR");
308+
309+
// Prune all the statically dead mask ops in the epilogue. This is a
310+
// hack, ideally we should do it for all the mask ops, but it is incorrect if
311+
// we have speculatively executed async cp operations that will store to shmem
312+
// even if the mask is false.
313+
for (auto maskOp : peeledMaskOps) {
314+
rewriter.setInsertionPoint(maskOp);
315+
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
316+
Operation *op = &maskOp.getBody()->front();
317+
if (isConstantIntValue(maskOp.getPred(), 0)) {
318+
if (op->getNumResults() > 0) {
319+
SmallVector<Value> results;
320+
for (auto result : op->getResults()) {
321+
auto poisonOp = rewriter.create<mlir::ub::PoisonOp>(
322+
op->getLoc(), result.getType());
323+
results.push_back(poisonOp);
324+
}
325+
op->replaceAllUsesWith(results);
326+
}
327+
op->erase();
328+
}
329+
}
330+
}
331+
332+
SmallVector<ttg::MaskOp> maskOps;
333+
moduleOp->walk([&](ttg::MaskOp maskOp) { maskOps.push_back(maskOp); });
334+
for (auto maskOp : maskOps) {
335+
rewriter.setInsertionPoint(maskOp);
336+
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
337+
Operation *op = &maskOp.getBody()->front();
338+
rewriter.moveOpBefore(op, maskOp);
339+
op = triton::predicateOp(rewriter, op, maskOp.getPred());
340+
}
341+
maskOp->replaceAllUsesWith(
342+
maskOp.getBody()->getTerminator()->getOperands());
343+
maskOp->erase();
344+
}
345+
}
346+
282347
// Return true if the given ForOp has the attribute
283348
// `tt.disallow_acc_multi_buffer` set to true.
284349
bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2-
#include "mlir/Dialect/UB/IR/UBOps.h"
32
#include "mlir/IR/TypeUtilities.h"
43
#include "mlir/Interfaces/SideEffectInterfaces.h"
54
#include "mlir/Support/LLVM.h"
@@ -43,67 +42,6 @@ static void pipelineWgmma(ModuleOp moduleOp, unsigned numStages) {
4342
}
4443
}
4544

46-
static Operation *wrapInMaskOp(RewriterBase &rewriter, Operation *op,
47-
Value pred) {
48-
auto mask = rewriter.create<MaskOp>(op->getLoc(), op->getResultTypes(), pred);
49-
rewriter.createBlock(&mask->getRegion(0));
50-
rewriter.setInsertionPointToStart(&mask->getRegion(0).front());
51-
auto newOp = rewriter.clone(*op);
52-
rewriter.create<MaskReturnOp>(op->getLoc(), newOp->getResults());
53-
op->replaceAllUsesWith(mask->getResults());
54-
rewriter.eraseOp(op);
55-
return mask;
56-
}
57-
58-
static void resolveMaskOp(ModuleOp moduleOp, DenseSet<MaskOp> &peeledMaskOps) {
59-
IRRewriter rewriter(moduleOp);
60-
61-
// Canonicalize the IR to simplify the arithmetic ops defining the mask
62-
auto arithDialect =
63-
moduleOp.getContext()->getLoadedDialect<arith::ArithDialect>();
64-
RewritePatternSet patterns(moduleOp.getContext());
65-
arithDialect->getCanonicalizationPatterns(patterns);
66-
if (applyPatternsGreedily(moduleOp, std::move(patterns)).failed())
67-
return llvm::report_fatal_error("Failed to canonicalize the IR");
68-
69-
// Prune all the statically dead mask ops in the epilogue. This is a
70-
// hack, ideally we should do it for all the mask ops, but it is incorrect if
71-
// we have speculatively executed async cp operations that will store to shmem
72-
// even if the mask is false.
73-
for (auto maskOp : peeledMaskOps) {
74-
rewriter.setInsertionPoint(maskOp);
75-
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
76-
Operation *op = &maskOp.getBody()->front();
77-
if (isConstantIntValue(maskOp.getPred(), 0)) {
78-
if (op->getNumResults() > 0) {
79-
SmallVector<Value> results;
80-
for (auto result : op->getResults()) {
81-
auto poisonOp =
82-
rewriter.create<ub::PoisonOp>(op->getLoc(), result.getType());
83-
results.push_back(poisonOp);
84-
}
85-
op->replaceAllUsesWith(results);
86-
}
87-
op->erase();
88-
}
89-
}
90-
}
91-
92-
SmallVector<MaskOp> maskOps;
93-
moduleOp->walk([&](MaskOp maskOp) { maskOps.push_back(maskOp); });
94-
for (auto maskOp : maskOps) {
95-
rewriter.setInsertionPoint(maskOp);
96-
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
97-
Operation *op = &maskOp.getBody()->front();
98-
rewriter.moveOpBefore(op, maskOp);
99-
op = triton::predicateOp(rewriter, op, maskOp.getPred());
100-
}
101-
maskOp->replaceAllUsesWith(
102-
maskOp.getBody()->getTerminator()->getOperands());
103-
maskOp->erase();
104-
}
105-
}
106-
10745
static bool hasMMAv5WaitsInLastStage(scf::ForOp forOp,
10846
CoarseSchedule &schedule) {
10947
int maxStage = schedule.getNumStages() - 1;

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Operation *streamPredication(RewriterBase &rewriter, Operation *op,
5353
ifOp.getElseBodyBuilder().create<scf::YieldOp>(loc, dotOp->getOperand(2));
5454
return ifOp;
5555
}
56-
return tt::predicateOp(rewriter, op, pred);
56+
return tt::wrapInMaskOp(rewriter, op, pred);
5757
}
5858

5959
//===----------------------------------------------------------------------===//
@@ -974,9 +974,9 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo,
974974
}
975975
} // namespace ChainedDotSchedule
976976

977-
LogicalResult pipelineLoop(scf::ForOp forOp, int numStages, int globalPrefetch,
978-
int localPrefetch, bool useAsyncCopy,
979-
bool waitAtTail) {
977+
FailureOr<scf::ForOp> pipelineLoop(scf::ForOp forOp, int numStages,
978+
int globalPrefetch, int localPrefetch,
979+
bool useAsyncCopy, bool waitAtTail) {
980980

981981
triton::AMD::ModuleAxisInfoAnalysis axisInfoAnalysis(
982982
forOp->getParentOfType<ModuleOp>());
@@ -1019,9 +1019,24 @@ LogicalResult pipelineLoop(scf::ForOp forOp, int numStages, int globalPrefetch,
10191019
if (part != tt::PipeliningOption::PipelinerPart::Prologue)
10201020
return;
10211021

1022-
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
1022+
auto annotateLoad = [](Operation *loadOp) {
10231023
loadOp->setAttr("amd.pipeliner_part",
1024-
StringAttr::get(op->getContext(), "prologue"));
1024+
StringAttr::get(loadOp->getContext(), "prologue"));
1025+
};
1026+
1027+
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
1028+
annotateLoad(loadOp);
1029+
return;
1030+
}
1031+
// loadOp may be wrapped by a MaskOp as predicateFn execution
1032+
// precedes annotation
1033+
if (auto maskOp = dyn_cast<ttg::MaskOp>(op)) {
1034+
for (auto &innerOp : maskOp.getBody()->without_terminator()) {
1035+
if (auto loadOp = dyn_cast<tt::LoadOp>(&innerOp)) {
1036+
annotateLoad(loadOp);
1037+
return;
1038+
}
1039+
}
10251040
}
10261041
};
10271042
// Set the final schedule as our scheduling function
@@ -1077,6 +1092,10 @@ struct PipelinePass : impl::TritonAMDGPUStreamPipelineBase<PipelinePass> {
10771092
useAsyncCopy, waitAtTail);
10781093
}
10791094

1095+
// NOTE: Leave empty for now, until we utilize customEpiloguePeeling
1096+
DenseSet<ttg::MaskOp> peeledMaskOps;
1097+
tt::resolveMaskOp(moduleOp, peeledMaskOps);
1098+
10801099
if (useAsyncCopy) {
10811100
llvm::SmallSetVector<ttg::AsyncWaitOp, 8> waitOps;
10821101
moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); });

0 commit comments

Comments
 (0)