Skip to content

Commit 4b9a8d6

Browse files
committed
Merge commit '8dd7ccb26c37ed2521a38a60e94dbff685c662cd'
2 parents 82b1f85 + 8dd7ccb commit 4b9a8d6

File tree

28 files changed

+1040
-176
lines changed

28 files changed

+1040
-176
lines changed

docs/programming-guide/chapter-3/debugging.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,6 @@ Using Third-party Tools
7777
For debugging on NVIDIA GPUs, `compute-sanitizer <https://docs.nvidia.com/cuda/compute-sanitizer/index.html>`_ is an effective tool for checking data races and memory access issues.
7878
To use it, prepend :code:`compute-sanitizer` to your command to run the Triton program.
7979

80-
For debugging on AMD GPUs, you may want to try the LLVM `AddressSanitizer <https://rocm.docs.amd.com/en/latest/conceptual/using-gpu-sanitizer.html>`_ for ROCm.
80+
For debugging on AMD GPUs, you may want to try the LLVM `AddressSanitizer <https://rocm.docs.amd.com/projects/llvm-project/en/latest/conceptual/using-gpu-sanitizer.html>`_ for ROCm.
8181

8282
For detailed visualization of memory access in Triton programs, consider using the `triton-viz <https://github.com/Deep-Learning-Profiling-Tools/triton-viz>`_ tool, which is agnostic to the underlying GPUs.

include/triton/Analysis/Allocation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy);
6363
ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
6464
RankedTensorType dstTy);
6565

66+
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
67+
RankedTensorType dstTy);
68+
69+
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
70+
RankedTensorType dstTy);
6671
} // namespace triton
6772

6873
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ bool isOuterLoop(scf::ForOp forOp);
6868
/// Function to mask operations during scheduling.
6969
Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred);
7070

71+
/// Wrap the operation into a MaskOp using the provided predicate, enabling high
72+
/// level predication abstraction during pipelining.
73+
Operation *wrapInMaskOp(RewriterBase &rewriter, Operation *op, Value pred);
74+
75+
// Utilize high level predication abstraction to perform optimizations before
76+
// lowering to predicated operations.
77+
void resolveMaskOp(ModuleOp moduleOp,
78+
DenseSet<triton::gpu::MaskOp> &peeledMaskOps);
79+
7180
// Return true if the given ForOp has the attribute
7281
// `tt.disallow_acc_multi_buffer` set to true.
7382
bool getDisallowAccMultiBuffer(scf::ForOp forOp);

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
3535
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
3636

37+
namespace mlir::triton::nvidia_gpu::impl {
38+
LogicalResult verifyMMAv5Op(Operation *op);
39+
} // namespace mlir::triton::nvidia_gpu::impl
40+
3741
#define GET_ATTRDEF_CLASSES
3842
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"
3943

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,9 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
5454
"setIsAsync",
5555
(ins "bool":$isAsync)>,
5656
];
57+
58+
let verify = [{
59+
return ::mlir::triton::nvidia_gpu::impl::verifyMMAv5Op($_op);
60+
}];
5761
}
5862
#endif // TRITON_NVIDIAGPU_OP_INTERFACES

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ static unsigned getBitwidth(RankedTensorType ty) {
3939
return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
4040
}
4141

42-
static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
43-
RankedTensorType dstTy) {
42+
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
43+
RankedTensorType dstTy) {
4444
auto *ctx = srcTy.getContext();
4545
auto srcLayout = gpu::toLinearLayout(srcTy);
4646
auto dstLayout = gpu::toLinearLayout(dstTy);
@@ -52,8 +52,8 @@ static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
5252
return smem.getTotalOutDimSize() / reps;
5353
}
5454

55-
static unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56-
RankedTensorType dstTy) {
55+
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56+
RankedTensorType dstTy) {
5757
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
5858
return getNumScratchElements(scratchConfig.paddedRepShape);
5959
}

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,13 @@ class SinkTMEMLoad : public OpRewritePattern<ttng::TMEMLoadOp> {
141141
return postDomInfo.properlyPostDominates(use->getOwner(), domOp);
142142
}))
143143
return failure();
144-
if (domOp == load->getNextNode()) {
144+
// In order to not re-order multiple tmem loads in a loop, don't sink if
145+
// all the ops between the load and the domOp are tmem loads.
146+
Operation *nextNode = load->getNextNode();
147+
while (auto tmemLoad = dyn_cast<ttng::TMEMLoadOp>(nextNode)) {
148+
nextNode = tmemLoad->getNextNode();
149+
}
150+
if (domOp == nextNode) {
145151
// The load wasn't moved.
146152
return failure();
147153
}

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
#include "mlir/Analysis/TopologicalSortUtils.h"
33
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
44
#include "mlir/Dialect/Tensor/IR/Tensor.h"
5+
#include "mlir/Dialect/UB/IR/UBOps.h"
56
#include "mlir/IR/ImplicitLocOpBuilder.h"
67
#include "mlir/IR/TypeUtilities.h"
78
#include "mlir/Interfaces/SideEffectInterfaces.h"
89
#include "mlir/Support/LLVM.h"
10+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
911
#include "triton/Analysis/AxisInfo.h"
1012
#include "triton/Dialect/Triton/IR/Utility.h"
1113
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -279,6 +281,69 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
279281
return op;
280282
}
281283

284+
Operation *mlir::triton::wrapInMaskOp(RewriterBase &rewriter, Operation *op,
285+
Value pred) {
286+
auto mask =
287+
rewriter.create<ttg::MaskOp>(op->getLoc(), op->getResultTypes(), pred);
288+
rewriter.createBlock(&mask->getRegion(0));
289+
rewriter.setInsertionPointToStart(&mask->getRegion(0).front());
290+
auto newOp = rewriter.clone(*op);
291+
rewriter.create<ttg::MaskReturnOp>(op->getLoc(), newOp->getResults());
292+
op->replaceAllUsesWith(mask->getResults());
293+
rewriter.eraseOp(op);
294+
return mask;
295+
}
296+
297+
void mlir::triton::resolveMaskOp(ModuleOp moduleOp,
298+
DenseSet<ttg::MaskOp> &peeledMaskOps) {
299+
IRRewriter rewriter(moduleOp);
300+
301+
// Canonicalize the IR to simplify the arithmetic ops defining the mask
302+
auto arithDialect =
303+
moduleOp.getContext()->getLoadedDialect<arith::ArithDialect>();
304+
RewritePatternSet patterns(moduleOp.getContext());
305+
arithDialect->getCanonicalizationPatterns(patterns);
306+
if (mlir::applyPatternsGreedily(moduleOp, std::move(patterns)).failed())
307+
return llvm::report_fatal_error("Failed to canonicalize the IR");
308+
309+
// Prune all the statically dead mask ops in the epilogue. This is a
310+
// hack, ideally we should do it for all the mask ops, but it is incorrect if
311+
// we have speculatively executed async cp operations that will store to shmem
312+
// even if the mask is false.
313+
for (auto maskOp : peeledMaskOps) {
314+
rewriter.setInsertionPoint(maskOp);
315+
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
316+
Operation *op = &maskOp.getBody()->front();
317+
if (isConstantIntValue(maskOp.getPred(), 0)) {
318+
if (op->getNumResults() > 0) {
319+
SmallVector<Value> results;
320+
for (auto result : op->getResults()) {
321+
auto poisonOp = rewriter.create<mlir::ub::PoisonOp>(
322+
op->getLoc(), result.getType());
323+
results.push_back(poisonOp);
324+
}
325+
op->replaceAllUsesWith(results);
326+
}
327+
op->erase();
328+
}
329+
}
330+
}
331+
332+
SmallVector<ttg::MaskOp> maskOps;
333+
moduleOp->walk([&](ttg::MaskOp maskOp) { maskOps.push_back(maskOp); });
334+
for (auto maskOp : maskOps) {
335+
rewriter.setInsertionPoint(maskOp);
336+
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
337+
Operation *op = &maskOp.getBody()->front();
338+
rewriter.moveOpBefore(op, maskOp);
339+
op = triton::predicateOp(rewriter, op, maskOp.getPred());
340+
}
341+
maskOp->replaceAllUsesWith(
342+
maskOp.getBody()->getTerminator()->getOperands());
343+
maskOp->erase();
344+
}
345+
}
346+
282347
// Return true if the given ForOp has the attribute
283348
// `tt.disallow_acc_multi_buffer` set to true.
284349
bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2-
#include "mlir/Dialect/UB/IR/UBOps.h"
32
#include "mlir/IR/TypeUtilities.h"
43
#include "mlir/Interfaces/SideEffectInterfaces.h"
54
#include "mlir/Support/LLVM.h"
@@ -43,67 +42,6 @@ static void pipelineWgmma(ModuleOp moduleOp, unsigned numStages) {
4342
}
4443
}
4544

46-
static Operation *wrapInMaskOp(RewriterBase &rewriter, Operation *op,
47-
Value pred) {
48-
auto mask = rewriter.create<MaskOp>(op->getLoc(), op->getResultTypes(), pred);
49-
rewriter.createBlock(&mask->getRegion(0));
50-
rewriter.setInsertionPointToStart(&mask->getRegion(0).front());
51-
auto newOp = rewriter.clone(*op);
52-
rewriter.create<MaskReturnOp>(op->getLoc(), newOp->getResults());
53-
op->replaceAllUsesWith(mask->getResults());
54-
rewriter.eraseOp(op);
55-
return mask;
56-
}
57-
58-
static void resolveMaskOp(ModuleOp moduleOp, DenseSet<MaskOp> &peeledMaskOps) {
59-
IRRewriter rewriter(moduleOp);
60-
61-
// Canonicalize the IR to simplify the arithmetic ops defining the mask
62-
auto arithDialect =
63-
moduleOp.getContext()->getLoadedDialect<arith::ArithDialect>();
64-
RewritePatternSet patterns(moduleOp.getContext());
65-
arithDialect->getCanonicalizationPatterns(patterns);
66-
if (applyPatternsGreedily(moduleOp, std::move(patterns)).failed())
67-
return llvm::report_fatal_error("Failed to canonicalize the IR");
68-
69-
// Prune all the statically dead mask ops in the epilogue. This is a
70-
// hack, ideally we should do it for all the mask ops, but it is incorrect if
71-
// we have speculatively executed async cp operations that will store to shmem
72-
// even if the mask is false.
73-
for (auto maskOp : peeledMaskOps) {
74-
rewriter.setInsertionPoint(maskOp);
75-
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
76-
Operation *op = &maskOp.getBody()->front();
77-
if (isConstantIntValue(maskOp.getPred(), 0)) {
78-
if (op->getNumResults() > 0) {
79-
SmallVector<Value> results;
80-
for (auto result : op->getResults()) {
81-
auto poisonOp =
82-
rewriter.create<ub::PoisonOp>(op->getLoc(), result.getType());
83-
results.push_back(poisonOp);
84-
}
85-
op->replaceAllUsesWith(results);
86-
}
87-
op->erase();
88-
}
89-
}
90-
}
91-
92-
SmallVector<MaskOp> maskOps;
93-
moduleOp->walk([&](MaskOp maskOp) { maskOps.push_back(maskOp); });
94-
for (auto maskOp : maskOps) {
95-
rewriter.setInsertionPoint(maskOp);
96-
while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
97-
Operation *op = &maskOp.getBody()->front();
98-
rewriter.moveOpBefore(op, maskOp);
99-
op = triton::predicateOp(rewriter, op, maskOp.getPred());
100-
}
101-
maskOp->replaceAllUsesWith(
102-
maskOp.getBody()->getTerminator()->getOperands());
103-
maskOp->erase();
104-
}
105-
}
106-
10745
static bool hasMMAv5WaitsInLastStage(scf::ForOp forOp,
10846
CoarseSchedule &schedule) {
10947
int maxStage = schedule.getNumStages() - 1;

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ Attribute getTmemLoadStoreLayout32x32b(unsigned M, unsigned N,
116116
unsigned numWarpGroups = numWarps / 4;
117117
if (numBlocks == 1) {
118118
// Split along the N dimension
119-
sizePerThread = {1, N / (numWarpGroups * 2)};
119+
sizePerThread = {1, ceil<unsigned>(N, numWarpGroups * 2)};
120120
threadsPerWarp = {16, 2};
121121
warpsPerCTA = {4, numWarpGroups};
122122
} else {
123-
sizePerThread = {1, N / 2};
123+
sizePerThread = {1, ceil<unsigned>(N, 2)};
124124
threadsPerWarp = {16, 2};
125125
warpsPerCTA = {0, 0};
126126
// Distribute at most as many warp groups as there is blocks
@@ -138,7 +138,7 @@ Attribute getTmemLoadStoreLayout32x32b(unsigned M, unsigned N,
138138
warpsPerCTA = {4 * numWarpGroups, 1};
139139
} else {
140140
// Split along N dimension
141-
sizePerThread = {1, N / numWarpGroups};
141+
sizePerThread = {1, ceil<unsigned>(N, numWarpGroups)};
142142
threadsPerWarp = {32, 1};
143143
warpsPerCTA = {4, numWarpGroups};
144144
}
@@ -223,6 +223,22 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
223223
return areLayoutsEquivalent(tensorType.getShape(), layout, enc);
224224
}
225225

226+
LogicalResult impl::verifyMMAv5Op(Operation *op) {
227+
auto isInterleaved = [](MemDescType memdesc) {
228+
auto enc = dyn_cast<TensorMemoryEncodingAttr>(memdesc.getEncoding());
229+
return enc && getTmemAllocSizes(memdesc).numRows != 64 &&
230+
enc.getBlockM() == 64;
231+
};
232+
233+
auto itf = cast<MMAv5OpInterface>(op);
234+
if (isInterleaved(itf.getA().getType()) &&
235+
isInterleaved(itf.getAccumulator().getType())) {
236+
return op->emitOpError(
237+
"does not support blockM=64 with interleaved blocks in TMEM layout");
238+
}
239+
return success();
240+
}
241+
226242
} // namespace nvidia_gpu
227243
} // namespace triton
228244
} // namespace mlir

0 commit comments

Comments
 (0)