
Commit 3887b80

[TMA] Enable unswizzled tma layouts (#6238)
This enables TMA support for smaller load blocks by falling back to unswizzled shared-memory encodings where necessary. We are also careful to propagate the shape info from gather/scatter instructions so these can still enable swizzling where possible.
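
For context, here is a rough sketch (not taken from this commit) of the kind of small-block descriptor usage this unblocks. It assumes a TMA-capable (Hopper-class) GPU and uses `triton.set_allocator` the same way the accompanying tests do; the kernel and function names (`small_block_copy`, `run`) are hypothetical.

# Hypothetical illustration: a (2, 16) block has fewer than 8 rows, so the
# NVMMA swizzled shared-memory layout cannot be used; the compiler now falls
# back to an unswizzled encoding instead of rejecting the kernel.
import torch
import triton
import triton.language as tl


@triton.jit
def small_block_copy(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    desc = tl.make_tensor_descriptor(a_ptr, [M, N], [N, 1], [M_BLOCK, N_BLOCK])
    block = desc.load([0, 0])  # TMA load of an M_BLOCK x N_BLOCK tile
    idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
    tl.store(out_ptr + idx, block)


def run():
    # Device-side tensor descriptors need a host-side allocator, as in the tests below.
    triton.set_allocator(lambda size, align, stream: torch.empty(size, dtype=torch.int8, device="cuda"))
    M, N, M_BLOCK, N_BLOCK = 32, 128, 2, 16
    a = torch.randn((M, N), device="cuda", dtype=torch.float32)
    out = torch.empty((M_BLOCK, N_BLOCK), device="cuda", dtype=torch.float32)
    small_block_copy[(1, )](out, a, M, N, M_BLOCK, N_BLOCK)
    return out
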
1 parent 188048c

8 files changed: +141, -60 lines

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -430,7 +430,7 @@ def NVMMASharedEncodingAttr :
       } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
         swizzlingByteWidth = 32;
       } else {
-        llvm_unreachable("unsupported shared memory layout for MMAv3");
+        llvm_unreachable("unsupported NVMMA layout (MMAv3 or TMA)");
       }
       bool transposed = order[0] == 0;
       return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, fp4Padded, CTALayout);
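
For reference, a small Python sketch of how the swizzle width is derived from the contiguous-dimension size in bytes. Only the 32-byte branch is visible in this hunk; the 128- and 64-byte branches are an assumption inferred from the else-if chain.

def swizzling_byte_width(contig_dim_size_in_bytes: int) -> int:
    # Mirrors the else-if chain above; the 128/64 cases are assumed from context.
    for width in (128, 64, 32):
        if contig_dim_size_in_bytes >= width and contig_dim_size_in_bytes % width == 0:
            return width
    raise ValueError("unsupported NVMMA layout (MMAv3 or TMA)")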

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 8 additions & 3 deletions
@@ -156,11 +156,16 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
           "elem type .b4x16_p64 supports only 128B swizzling");
     }
   } else {
-    op->emitError() << "Unhandled encoding type";
-    return failure();
+    auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(
+        op.getType().getBlockType().getEncoding());
+    if (!swizzledEnc || swizzledEnc.getVec() != 1 ||
+        swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) {
+      op->emitError() << "Unhandled encoding type";
+      return failure();
+    }
   }

-  int32_t swizzle_mode;
+  int32_t swizzle_mode = 0;
   if (swizzleBytes == 128) {
     swizzle_mode = 3;
   } else if (swizzleBytes == 64) {
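
The hunk is truncated after the 64-byte case. For orientation, the `swizzle_mode` values appear to follow the CUDA tensor-map swizzle enumeration; this mapping is an assumption, since only the 128B case (and the new 0 default for unswizzled layouts) is visible here.

# swizzleBytes -> swizzle_mode, matching CUtensorMapSwizzle
# (0 = none, 1 = 32B, 2 = 64B, 3 = 128B); unswizzled layouts now take the 0 path.
SWIZZLE_MODE = {0: 0, 32: 1, 64: 2, 128: 3}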

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp

Lines changed: 69 additions & 26 deletions
@@ -1,22 +1,16 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/Passes.h"
-#include "triton/Analysis/AxisInfo.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/PriorityWorklist.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/VersionTuple.h"
+#include <algorithm>
 #include <memory>
 #include <unordered_set>

@@ -35,6 +29,7 @@ struct UseInfo {
   TypedValue<tt::TensorDescType> descriptor;
   Operation *use;
   Attribute desiredSharedEncoding;
+  SmallVector<int64_t> shape;
   ttg::CTALayoutAttr ctaLayout;
 };

@@ -72,6 +67,14 @@ ttg::CTALayoutAttr getCtaLayoutFromEncoding(Attribute encoding) {
                                  layout.getCTASplitNum(), layout.getCTAOrder());
 }

+SmallVector<int64_t> expandToRank(ArrayRef<int64_t> shape, int rank) {
+  SmallVector<int64_t> result(rank, 1);
+  assert(shape.size() <= rank);
+  auto rankDiff = rank - shape.size();
+  std::copy(shape.begin(), shape.end(), result.begin() + rankDiff);
+  return result;
+}
+
 std::optional<UseInfo> getUseInfo(Operation *op) {
   UseInfo info;
   info.use = op;
@@ -81,6 +84,9 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
     auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
                                                : load.getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = load.getResult().getType().getShape();
+    auto rank = load.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto gather = dyn_cast<tt::DescriptorGatherOp>(op)) {
@@ -89,18 +95,27 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
     auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
                                                : gather.getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = gather.getResult().getType().getShape();
+    auto rank = gather.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto store = dyn_cast<tt::DescriptorStoreOp>(op)) {
     info.descriptor = store.getDesc();
     auto encoding = store.getSrc().getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = store.getSrc().getType().getShape();
+    auto rank = store.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto scatter = dyn_cast<tt::DescriptorScatterOp>(op)) {
     info.descriptor = scatter.getDesc();
     auto encoding = scatter.getSrc().getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = scatter.getSrc().getType().getShape();
+    auto rank = scatter.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   return std::nullopt;
@@ -109,12 +124,15 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
 struct EncodingInfo {
   Attribute desiredEncoding;
   ttg::CTALayoutAttr ctaLayout;
+  // Shape may be different from the descriptor block shape for gather/scatter
+  // use case
+  SmallVector<int64_t> shape;
   bool forcedToDefault = false;

   bool operator==(const EncodingInfo &other) const {
     return desiredEncoding == other.desiredEncoding &&
            ctaLayout == other.ctaLayout &&
-           forcedToDefault == other.forcedToDefault;
+           forcedToDefault == other.forcedToDefault && shape == other.shape;
   }
 };

@@ -123,7 +141,8 @@ struct EncodingInfo {
 template <> struct std::hash<EncodingInfo> {
   size_t operator()(const EncodingInfo &einfo) const {
     return llvm::hash_combine(einfo.desiredEncoding, einfo.ctaLayout,
-                              einfo.forcedToDefault);
+                              einfo.forcedToDefault,
+                              ArrayRef<int64_t>(einfo.shape));
   }
 };

@@ -172,6 +191,21 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
   // Always propagate forcedToDefault
   result.forcedToDefault = lhs.forcedToDefault || rhs.forcedToDefault;

+  if (result.forcedToDefault)
+    return result;
+
+  if (lhs.shape.empty() || lhs.shape == rhs.shape)
+    result.shape = rhs.shape;
+  else if (rhs.shape.empty())
+    result.shape = lhs.shape;
+  else {
+    assert(lhs.shape.size() == rhs.shape.size());
+    auto rank = lhs.shape.size();
+    result.shape.reserve(rank);
+    for (int i = 0; i < rank; ++i)
+      result.shape.push_back(std::min(lhs.shape[i], rhs.shape[i]));
+  }
+
   SetVector<ttg::CTALayoutAttr> ctaLayouts;
   if (lhs.ctaLayout)
     ctaLayouts.insert(lhs.ctaLayout);
@@ -190,9 +224,6 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
     break;
   }

-  if (result.forcedToDefault)
-    return result;
-
   SetVector<Attribute> desiredEncodings;
   if (lhs.desiredEncoding)
     desiredEncodings.insert(lhs.desiredEncoding);
@@ -213,23 +244,32 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
 }

 Attribute getFallbackSharedEncoding(RankedTensorType tensorType,
-                                    ttg::CTALayoutAttr ctaLayout) {
+                                    ttg::CTALayoutAttr ctaLayout,
+                                    ArrayRef<int64_t> usageShape) {
   auto ctx = tensorType.getContext();
   SmallVector<unsigned> order;
   for (int i = tensorType.getRank() - 1; i >= 0; --i)
     order.push_back(i);

+  ArrayRef<int64_t> shape =
+      usageShape.empty() ? tensorType.getShape() : usageShape;
   if (!ctaLayout)
     ctaLayout = ttg::CTALayoutAttr::getDefault(ctx, tensorType.getRank());
   else if (ctaLayout.getRank() != tensorType.getRank())
-    ctaLayout = ttng::updateCTALayoutForShape(ctaLayout, tensorType.getShape());
+    ctaLayout = ttng::updateCTALayoutForShape(ctaLayout, shape);
+
+  auto elemTy = tensorType.getElementType();
+  auto shapePerCTA = ttg::getShapePerCTA(ctaLayout.getCTASplitNum(), shape);
+  unsigned eleBitWidth = tensorType.getElementType().getIntOrFloatBitWidth();

-  if (tensorType.getRank() == 1) {
+  auto contigDimSizeInBytes = shapePerCTA.back() * eleBitWidth / 8;
+  auto rank = tensorType.getRank();
+  if (rank == 1 || contigDimSizeInBytes < 32 || shape[rank - 2] < 8) {
     return ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, order, ctaLayout);
   }
-  return ttg::NVMMASharedEncodingAttr::get(
-      ctx, tensorType.getShape(), order, ctaLayout, tensorType.getElementType(),
-      /*fp4Padded*/ false);
+  return ttg::NVMMASharedEncodingAttr::get(ctx, shape, order, ctaLayout,
+                                           tensorType.getElementType(),
+                                           /*fp4Padded*/ false);
 }

 tt::TensorDescType getTensorDescTypeWithEncoding(Operation *op,
@@ -274,17 +314,19 @@ void assignMemoryLayouts(tt::FuncOp &func) {
   // fallback to default encoding
   for (auto blockArg : func.getBlocks().front().getArguments())
     if (auto desc = dyn_cast<TypedValue<tt::TensorDescType>>(blockArg))
-      updateEncoding({desc}, EncodingInfo{{}, {}, /*forcedToDefault=*/true});
+      updateEncoding({desc},
+                     EncodingInfo{{}, {}, {}, /*forcedToDefault=*/true});

   func.walk([&](Operation *op) {
     if (auto info = getUseInfo(op)) {
-      updateEncoding(info->descriptor, EncodingInfo{info->desiredSharedEncoding,
-                                                    info->ctaLayout});
+      updateEncoding(info->descriptor,
+                     EncodingInfo{info->desiredSharedEncoding, info->ctaLayout,
+                                  info->shape});
     } else {
       bool forcedToDefault =
           isa<tt::CallOp, tt::ReturnOp, tt::ReinterpretTensorDescOp>(op);
       auto einfo =
-          internEncoding(encodings, EncodingInfo{{}, {}, forcedToDefault});
+          internEncoding(encodings, EncodingInfo{{}, {}, {}, forcedToDefault});

       auto setEncoding = [&](Value v) {
         auto typedVal = cast<TypedValue<tt::TensorDescType>>(v);
@@ -344,9 +386,10 @@ void assignMemoryLayouts(tt::FuncOp &func) {
     if (einfo->desiredEncoding) {
       newEncoding = einfo->desiredEncoding;
     } else if (einfo->forcedToDefault) {
-      newEncoding = getFallbackSharedEncoding(existingTy, {});
+      newEncoding = getFallbackSharedEncoding(existingTy, {}, {});
     } else {
-      newEncoding = getFallbackSharedEncoding(existingTy, einfo->ctaLayout);
+      newEncoding =
+          getFallbackSharedEncoding(existingTy, einfo->ctaLayout, einfo->shape);
     }
     desc.setType(getTensorDescTypeWithEncoding(desc.getDefiningOp(), existingTy,
                                                newEncoding));
@@ -356,14 +399,14 @@ void assignMemoryLayouts(tt::FuncOp &func) {
   SmallVector<Type> resultTys(func.getResultTypes());
   for (auto [i, argTy] : llvm::enumerate(argTys)) {
     if (auto descTy = dyn_cast<tt::TensorDescType>(argTy)) {
-      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {});
+      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {}, {});
       argTys[i] = getTensorDescTypeWithEncoding(nullptr, descTy.getBlockType(),
                                                 encoding);
     }
   }
   for (auto [i, resultTy] : llvm::enumerate(resultTys)) {
     if (auto descTy = dyn_cast<tt::TensorDescType>(resultTy)) {
-      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {});
+      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {}, {});
       resultTys[i] = getTensorDescTypeWithEncoding(
           nullptr, descTy.getBlockType(), encoding);
     }
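
A compact restatement of the new fallback rule in getFallbackSharedEncoding, as a Python sketch. CTA splitting is simplified to a single divisor on the last dimension; this is an illustration only, not code from the pass.

def uses_nvmma_swizzled_layout(block_shape, elem_bits, cta_split_last=1):
    # Swizzled NVMMA shared encodings need: rank >= 2, a contiguous dimension of
    # at least 32 bytes per CTA, and at least 8 rows; otherwise the pass now
    # falls back to SwizzledSharedEncodingAttr(vec=1, perPhase=1, maxPhase=1).
    if len(block_shape) == 1:
        return False
    contig_bytes = (block_shape[-1] // cta_split_last) * elem_bits // 8
    return contig_bytes >= 32 and block_shape[-2] >= 8


assert not uses_nvmma_swizzled_layout([2, 16], 16)  # too few rows -> unswizzled fallback
assert uses_nvmma_swizzled_layout([8, 32], 16)      # fp16 8x32 -> swizzled NVMMA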

python/test/unit/cuda/test_tensor_descriptor.py

Lines changed: 4 additions & 8 deletions
@@ -11,7 +11,8 @@
 @requires_tma
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_str", tma_dtypes)
-def test_tensor_descriptor_load(dtype_str):
+@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32)])
+def test_tensor_descriptor_load(dtype_str, M_BLOCK, N_BLOCK):

     @triton.jit
     def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
@@ -41,9 +42,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):

     M, N = 32, 128
     inp = to_triton(numpy_random((M, N), dtype_str), device="cuda", dst_type=dtype_str)
-
-    M_BLOCK = 8
-    N_BLOCK = 32
     out = inp.new_empty((M_BLOCK, N_BLOCK))

     kernel[(1, )](out, inp, M, N, M_BLOCK, N_BLOCK)
@@ -55,7 +53,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 @requires_tma
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_str", tma_dtypes)
-def test_tensor_descriptor_store(dtype_str):
+@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32)])
+def test_tensor_descriptor_store(dtype_str, M_BLOCK, N_BLOCK):

     @triton.jit
     def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
@@ -84,9 +83,6 @@ def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):

     M, N = 32, 128
     inp = to_triton(numpy_random((M, N), dtype_str), device="cuda", dst_type=dtype_str)
-
-    M_BLOCK = 8
-    N_BLOCK = 32
     out = inp.new_empty((M, N))

     grid_m = M // M_BLOCK

python/test/unit/language/test_core.py

Lines changed: 5 additions & 5 deletions
@@ -4974,31 +4974,31 @@ def test_tma_load_block_shape_err(device):

     @triton.jit
     def kernel(ptr):
-        desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32])
+        desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 2])
         desc.load([0, 0])

     input = torch.empty((128, 128), dtype=torch.int32, device=device)
     errc = triton.CompilationError if not is_interpreter() else InterpreterError
     with pytest.raises(errc) as e:
         kernel[(1, )](input)

-    assert "tensor descriptor block shape must have at least 8 rows" in str(e.value.__cause__)
+    assert "Descriptor block shape must have at least 16 bytes" in str(e.value.__cause__)


 @pytest.mark.interpreter
 def test_tma_store_block_shape_err(device):

     @triton.jit
     def kernel(ptr):
-        desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8])
-        desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16))
+        desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 4])
+        desc.store([0, 0], tl.zeros([8, 4], dtype=tl.int16))

     input = torch.empty((128, 128), dtype=torch.int16, device=device)
     errc = triton.CompilationError if not is_interpreter() else InterpreterError
     with pytest.raises(errc) as e:
         kernel[(1, )](input)

-    assert "int16 tensor descriptor block shape must have at least 16 columns" in str(e.value.__cause__)
+    assert "Descriptor block shape must have at least 16 bytes" in str(e.value.__cause__)


 def test_trans_reshape(device):

python/triton/language/semantic.py

Lines changed: 7 additions & 13 deletions
@@ -1152,21 +1152,9 @@ def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type,
     return tl.tensor_descriptor_base(handle, block_ty)


-def validate_descriptor_block(shape, dtype):
-    if len(shape) != 2:
-        return
-    # Due to limitations of the shared memory encoding, the TMA bounding box has
-    # to be at least as big as the swizzle tile.
-    assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}"
-    min_cols = 32 // dtype.primitive_bitwidth * 8
-    assert shape[
-        1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}"
-
-
 def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
                     builder: ir.builder) -> tl.tensor:
     assert isinstance(desc, tl.tensor_descriptor_base)
-    validate_descriptor_block(desc.block_shape, desc.dtype)
     ndim = len(desc.block_shape)
     assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
@@ -1178,7 +1166,6 @@ def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache

 def descriptor_store(desc: tl.tensor_descriptor_base, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
     assert isinstance(desc, tl.tensor_descriptor_base)
-    validate_descriptor_block(desc.block_shape, desc.dtype)
     ndim = len(desc.block_shape)
     assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
     assert value.shape == desc.block_shape
@@ -1931,6 +1918,13 @@ def make_tensor_descriptor(
         raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
     if len(block_shape) != ndim:
         raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}")
+    assert isinstance(base.dtype, tl.pointer_type)
+    elem_size = base.dtype.element_ty.primitive_bitwidth // 8
+    contig_dim_size = tl._constexpr_to_value(block_shape[-1])
+    if contig_dim_size * elem_size < 16:
+        raise ValueError(
+            f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
+        )

     strides[-1] = tl._constexpr_to_value(strides[-1])
     if strides[-1] != 1:
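
The removed validate_descriptor_block checks are replaced by a single dtype-independent rule at descriptor creation: the last block dimension must cover at least 16 bytes. A tiny sketch of what that implies per element width (an illustration, not code from the commit):

def min_last_block_dim(elem_bits: int) -> int:
    # Smallest legal block_shape[-1] under the 16-byte rule enforced above.
    elem_size = elem_bits // 8
    return (16 + elem_size - 1) // elem_size


assert min_last_block_dim(32) == 4    # float32 / int32
assert min_last_block_dim(16) == 8    # float16 / int16
assert min_last_block_dim(8) == 16    # int8 / fp8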
