Skip to content

Commit 99506b7

Browse files
authored
[Pipeliner] Multi-buffer TMA descriptors (#5290)
### Commits in this PR
1. [Pipeliner] Multi-buffer TMA descriptors
2. Add tests for pipelined descriptor creation
3. Be more conservative about the number of TMA buffers to allocate
4. Update golden samples
5. Use the correct modulus for TMA updates
1 parent 2ec1d17 commit 99506b7

File tree

11 files changed

+960
-155
lines changed

11 files changed

+960
-155
lines changed

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_
33

44
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
#include <optional>
6+
#include <utility>
57
#include <vector>
68

79
namespace mlir {
@@ -38,6 +40,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
3840
// Return the minClusterId and maxClusterId for the given ForOp.
3941
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
4042
std::pair<int, int> getStageCluster(Operation *op);
43+
std::optional<std::pair<int, int>> maybeGetStageCluster(Operation *op);
4144
void setStageCluster(Operation *op, int stage, int cluster);
4245
} // namespace triton
4346
} // namespace mlir
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#pragma once
2+
#include "mlir/IR/BuiltinTypes.h"
3+
#include "mlir/IR/PatternMatch.h"
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
6+
namespace mlir::triton::nvidia_gpu {
7+
8+
constexpr inline int TMA_SIZE_BYTES = 128;
9+
constexpr inline int TMA_ALIGN = 128;
10+
11+
template <typename BuilderT>
12+
mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
13+
mlir::triton::MakeTensorDescOp op,
14+
BuilderT &builder) {
15+
using namespace mlir;
16+
MLIRContext *ctx = op.getContext();
17+
auto loc = op.getLoc();
18+
auto mkI32Constant = [&](int32_t val) {
19+
return builder.template create<arith::ConstantOp>(
20+
loc, builder.getI32Type(), builder.getI32IntegerAttr(val));
21+
};
22+
23+
auto elemType = op.getBase().getType().getPointeeType();
24+
auto elemSize = elemType.getIntOrFloatBitWidth() / 8;
25+
26+
int32_t contig_dim_size = op.getTensorShape().back();
27+
int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize;
28+
if (contig_dim_size_in_bytes > 128) {
29+
contig_dim_size = 128 / elemSize;
30+
}
31+
llvm::SmallVector<Value> boxDim;
32+
boxDim.push_back(mkI32Constant(contig_dim_size));
33+
for (int k = op.getTensorShape().size() - 2; k >= 0; --k) {
34+
boxDim.push_back(mkI32Constant(op.getTensorShape()[k]));
35+
}
36+
37+
int32_t swizzle_mode;
38+
if (contig_dim_size_in_bytes >= 128) {
39+
swizzle_mode = 3;
40+
} else if (contig_dim_size_in_bytes == 64) {
41+
swizzle_mode = 2;
42+
} else if (contig_dim_size_in_bytes == 32) {
43+
swizzle_mode = 1;
44+
} else {
45+
op->emitError()
46+
<< "contiguous box dimension must be at least 32 bytes but got "
47+
<< contig_dim_size_in_bytes;
48+
return failure();
49+
}
50+
51+
Value elemSizeVal = builder.template create<arith::ConstantOp>(
52+
loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
53+
Value globalStride = builder.template create<arith::MulIOp>(
54+
loc, op.getStrides()[0], elemSizeVal);
55+
// TODO: Workaround for ptxas bug, remove when we update ptxas
56+
Value four = builder.template create<arith::ConstantOp>(
57+
loc, builder.getI64Type(), builder.getI64IntegerAttr(4));
58+
globalStride =
59+
builder.template create<arith::ShRSIOp>(loc, globalStride, four);
60+
61+
int elemTypeEnum;
62+
switch (elemSize) {
63+
case 1: {
64+
elemTypeEnum = 0;
65+
break;
66+
}
67+
case 2: {
68+
elemTypeEnum = 1;
69+
break;
70+
}
71+
case 4: {
72+
elemTypeEnum = 2;
73+
break;
74+
}
75+
default: {
76+
op->emitError()
77+
<< "Tensor descriptor element type must have size 1, 2, or 4 but got "
78+
<< elemSize;
79+
return failure();
80+
}
81+
}
82+
83+
auto one = mkI32Constant(1);
84+
builder.template create<triton::ExperimentalTensormapCreateOp>(
85+
loc,
86+
/*desc_ptr=*/tmaPtr,
87+
/*global_address=*/op.getBase(),
88+
/*box_dim=*/boxDim,
89+
/*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
90+
/*global_stride=*/ValueRange{globalStride},
91+
/*element_strides=*/ValueRange{one, one},
92+
/*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum),
93+
/*interleave_layout*/ builder.getI32IntegerAttr(0),
94+
/*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode),
95+
/*fill_mode=*/builder.getI32IntegerAttr(0));
96+
return success();
97+
}
98+
99+
} // namespace mlir::triton::nvidia_gpu

0 commit comments

Comments
 (0)