
Commit c973b79

Merge commit 'aac457e8d9af7c17e91c9cdc55a431d029fe8782'
2 parents: b0e15fa + aac457e

6 files changed, 243 additions and 314 deletions

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 9 additions & 161 deletions
@@ -28,118 +28,18 @@ inline SmallVector<Value> translateTMAIndices(BuilderT &builder, Location loc,
   return indices;
 }

-inline gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout,
-                                                  ArrayRef<int64_t> shape) {
-  auto rank = shape.size();
-  if (ctaLayout.getRank() == rank)
-    return ctaLayout;
+gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout,
+                                           ArrayRef<int64_t> shape);

-  auto ctx = ctaLayout.getContext();
-  if (ctaLayout.getRank() > rank) {
-    unsigned rankDiff = ctaLayout.getRank() - rank;
-    return gpu::CTALayoutAttr::get(
-        ctx, ctaLayout.getCTAsPerCGA().drop_front(rankDiff),
-        ctaLayout.getCTASplitNum().drop_front(rankDiff),
-        ctaLayout.getCTAOrder().drop_front(rankDiff));
-  }
-  // For rank-reducing loads, we need to rank-increase the CTA Layout
-  auto rankDiff = rank - ctaLayout.getRank();
-  for (unsigned i = 0; i < rankDiff; ++i) {
-    assert(shape[i] == 1 && "Should only happen for rank-reducing loads");
-  }
-  SmallVector<unsigned> CTAsPerCGA(rank, 1);
-  SmallVector<unsigned> CTASplitNum(rank, 1);
-  SmallVector<unsigned> CTAOrder(rank, 1);
-
-  llvm::copy(ctaLayout.getCTAsPerCGA(), CTAsPerCGA.begin() + rankDiff);
-  llvm::copy(ctaLayout.getCTASplitNum(), CTASplitNum.begin() + rankDiff);
-  for (unsigned i = 0; i < rankDiff; ++i) {
-    CTAOrder[i] = rank - i;
-  }
-  llvm::copy(ctaLayout.getCTAOrder(), CTAOrder.begin() + rankDiff);
-  return gpu::CTALayoutAttr::get(ctx, CTAsPerCGA, CTASplitNum, CTAOrder);
-}
-
-inline gpu::SharedEncodingTrait
+gpu::SharedEncodingTrait
 updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding,
-                       RankedTensorType tensorType) {
-  auto ctx = encoding.getContext();
-  auto ctaLayout = gpu::getCTALayout(encoding);
-  if (auto nvmmaEnc = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding)) {
-    auto existingCta = nvmmaEnc.getCTALayout();
-    if (!existingCta)
-      return nvmmaEnc;
-
-    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
-    return gpu::NVMMASharedEncodingAttr::get(
-        ctx, nvmmaEnc.getSwizzlingByteWidth(), nvmmaEnc.getTransposed(),
-        nvmmaEnc.getElementBitWidth(), nvmmaEnc.getFp4Padded(), newCtaEnc);
-  }
-  if (auto swizEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(encoding)) {
-    auto existingCta = swizEnc.getCTALayout();
-    if (!existingCta)
-      return swizEnc;
-
-    auto rank = tensorType.getRank();
-    auto oldOrder = swizEnc.getOrder();
-    SmallVector<unsigned> order;
-    for (int i = 0; i + oldOrder.size() < rank; ++i)
-      order.push_back(rank - i - 1);
-    for (int i = 0; i < oldOrder.size(); ++i) {
-      // If it is a rank-reducing load, we need to drop the last dimensions.
-      if (oldOrder[i] >= rank)
-        continue;
-      order.push_back(oldOrder[i]);
-    }
-    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
-    return gpu::SwizzledSharedEncodingAttr::get(
-        ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),
-        order, newCtaEnc);
-  }
-
-  constexpr auto msg = "Internal Error: Unhandled tensor descriptor encoding";
-  if (op)
-    op->emitError() << msg;
-  llvm::report_fatal_error(msg);
-}
+                       RankedTensorType tensorType);

-inline triton::gpu::SharedEncodingTrait
+triton::gpu::SharedEncodingTrait
 getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,
-                          Value desc) {
-  auto descBlockType = cast<TensorDescType>(desc.getType()).getBlockType();
-  Attribute encoding = descBlockType.getEncoding();
-  if (!encoding) {
-    constexpr auto msg =
-        "Internal Error: Tensor descriptor should have encoding set";
-    if (op)
-      op->emitError() << msg;
-    llvm::report_fatal_error(msg);
-  }
-  auto sharedEnc = cast<gpu::SharedEncodingTrait>(encoding);
-  if (descBlockType.getShape() == tensorType.getShape())
-    return sharedEnc;
-
-  return updateEncodingForShape(op, sharedEnc, tensorType);
-}
+                          Value desc);

-inline int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape) {
-  assert(encoding);
-  auto mmaEncoding =
-      llvm::dyn_cast_or_null<gpu::NVMMASharedEncodingAttr>(encoding);
-
-  // The bounding box inner dimension must be less than or equal to the
-  // swizzle size.
-  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-  // We clamp the block size and the codegen will emit multiple copy
-  // operations.
-  if (mmaEncoding) {
-    auto elemSize = mmaEncoding.getElementBitWidth() / 8;
-    return mmaEncoding.getSwizzlingByteWidth() / elemSize;
-  }
-
-  auto shapePerCTA = gpu::getShapePerCTA(encoding, shape);
-  return shapePerCTA.back();
-}
+int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape);

 inline int64_t getTMAContigDim(RankedTensorType tensorType) {
   return getTMAContigDim(tensorType.getEncoding(), tensorType.getShape());
@@ -149,61 +49,9 @@ inline int64_t getTMAContigDim(gpu::MemDescType memDescType) {
   return getTMAContigDim(memDescType.getEncoding(), memDescType.getShape());
 }

-inline std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
-  auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
-  unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
-  if (!mmaEncoding) {
-    auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(encoding);
-    if (!swizzledEnc || swizzledEnc.getVec() != 1 ||
-        swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) {
-      if (op)
-        op->emitError("Unhandled encoding type");
-      return std::nullopt;
-    }
-  }
-
-  bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded();
-  assert(!fp4Padded || swizzleBytes == 128 &&
-                           "elem type .b4x16_p64 supports only 128B swizzling");
+std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty);

-  int32_t swizzleMode = 0;
-  if (swizzleBytes == 128) {
-    swizzleMode = 3;
-  } else if (swizzleBytes == 64) {
-    swizzleMode = 2;
-  } else if (swizzleBytes == 32) {
-    swizzleMode = 1;
-  }
-  return swizzleMode;
-}
-
-inline std::optional<int> getTMAElementType(Operation *op, TensorDescType ty) {
-  auto encoding = ty.getBlockType().getEncoding();
-  auto mmaEncoding = dyn_cast<gpu::NVMMASharedEncodingAttr>(encoding);
-  bool fp4Padded = mmaEncoding && mmaEncoding.getFp4Padded();
-
-  if (fp4Padded)
-    return 14; // .b4x16_p64
-
-  auto elemSize = ty.getBlockType().getElementTypeBitWidth() / 8;
-  switch (elemSize) {
-  case 1:
-    return 0;
-  case 2:
-    return 1;
-  case 4:
-    return 2;
-  default:
-    break;
-  }
-  if (op) {
-    op->emitError()
-        << "Tensor descriptor element type must have size 1, 2, or 4 but got "
-        << elemSize;
-  }
-  return std::nullopt;
-}
+std::optional<int> getTMAElementType(Operation *op, TensorDescType ty);

 template <typename BuilderT>
 mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
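These hunks leave only declarations in the header; the function bodies move verbatim into the new TMAUtilities.cpp further down. The clamping performed by the removed getTMAContigDim body for NVMMA shared encodings is easy to check by hand; below is a minimal plain-Python sketch of that arithmetic (illustrative only, not a Triton API, and the function name is made up):

def tma_contig_dim_nvmma(swizzling_byte_width: int, element_bit_width: int) -> int:
    # The TMA bounding-box inner dimension is clamped to the swizzle width
    # expressed in elements; codegen then emits multiple copy operations.
    elem_size_bytes = element_bit_width // 8
    return swizzling_byte_width // elem_size_bytes

assert tma_contig_dim_nvmma(128, 16) == 64  # 128B swizzle, 16-bit elements
assert tma_contig_dim_nvmma(64, 32) == 16   # 64B swizzle, 32-bit elements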

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ add_triton_library(TritonNvidiaGPUTransforms
   PromoteLHSToTMem.cpp
   TensorMemoryAllocation.cpp
   TMALowering.cpp
+  TMAUtilities.cpp

   DEPENDS
   TritonNvidiaGPUTransformsIncGen
lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+#include <triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h>
+
+namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
+
+namespace mlir::triton::nvidia_gpu {
+
+ttg::CTALayoutAttr updateCTALayoutForShape(ttg::CTALayoutAttr ctaLayout,
+                                           ArrayRef<int64_t> shape) {
+  auto rank = shape.size();
+  if (ctaLayout.getRank() == rank)
+    return ctaLayout;
+
+  auto ctx = ctaLayout.getContext();
+  if (ctaLayout.getRank() > rank) {
+    unsigned rankDiff = ctaLayout.getRank() - rank;
+    return ttg::CTALayoutAttr::get(
+        ctx, ctaLayout.getCTAsPerCGA().drop_front(rankDiff),
+        ctaLayout.getCTASplitNum().drop_front(rankDiff),
+        ctaLayout.getCTAOrder().drop_front(rankDiff));
+  }
+  // For rank-reducing loads, we need to rank-increase the CTA Layout
+  auto rankDiff = rank - ctaLayout.getRank();
+  for (unsigned i = 0; i < rankDiff; ++i) {
+    assert(shape[i] == 1 && "Should only happen for rank-reducing loads");
+  }
+  SmallVector<unsigned> CTAsPerCGA(rank, 1);
+  SmallVector<unsigned> CTASplitNum(rank, 1);
+  SmallVector<unsigned> CTAOrder(rank, 1);
+
+  llvm::copy(ctaLayout.getCTAsPerCGA(), CTAsPerCGA.begin() + rankDiff);
+  llvm::copy(ctaLayout.getCTASplitNum(), CTASplitNum.begin() + rankDiff);
+  for (unsigned i = 0; i < rankDiff; ++i) {
+    CTAOrder[i] = rank - i;
+  }
+  llvm::copy(ctaLayout.getCTAOrder(), CTAOrder.begin() + rankDiff);
+  return ttg::CTALayoutAttr::get(ctx, CTAsPerCGA, CTASplitNum, CTAOrder);
+}
+
+ttg::SharedEncodingTrait
+updateEncodingForShape(Operation *op, ttg::SharedEncodingTrait encoding,
+                       RankedTensorType tensorType) {
+  auto ctx = encoding.getContext();
+  auto ctaLayout = ttg::getCTALayout(encoding);
+  if (auto nvmmaEnc = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding)) {
+    auto existingCta = nvmmaEnc.getCTALayout();
+    if (!existingCta)
+      return nvmmaEnc;
+
+    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
+    return ttg::NVMMASharedEncodingAttr::get(
+        ctx, nvmmaEnc.getSwizzlingByteWidth(), nvmmaEnc.getTransposed(),
+        nvmmaEnc.getElementBitWidth(), nvmmaEnc.getFp4Padded(), newCtaEnc);
+  }
+  if (auto swizEnc = dyn_cast<ttg::SwizzledSharedEncodingAttr>(encoding)) {
+    auto existingCta = swizEnc.getCTALayout();
+    if (!existingCta)
+      return swizEnc;
+
+    auto rank = tensorType.getRank();
+    auto oldOrder = swizEnc.getOrder();
+    SmallVector<unsigned> order;
+    for (int i = 0; i + oldOrder.size() < rank; ++i)
+      order.push_back(rank - i - 1);
+    for (int i = 0; i < oldOrder.size(); ++i) {
+      // If it is a rank-reducing load, we need to drop the last dimensions.
+      if (oldOrder[i] >= rank)
+        continue;
+      order.push_back(oldOrder[i]);
+    }
+    auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
+    return ttg::SwizzledSharedEncodingAttr::get(
+        ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),
+        order, newCtaEnc);
+  }
+
+  constexpr auto msg = "Internal Error: Unhandled tensor descriptor encoding";
+  if (op)
+    op->emitError() << msg;
+  llvm::report_fatal_error(msg);
+}
+
+ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op,
+                                                   RankedTensorType tensorType,
+                                                   Value desc) {
+  auto descBlockType = cast<TensorDescType>(desc.getType()).getBlockType();
+  Attribute encoding = descBlockType.getEncoding();
+  if (!encoding) {
+    constexpr auto msg =
+        "Internal Error: Tensor descriptor should have encoding set";
+    if (op)
+      op->emitError() << msg;
+    llvm::report_fatal_error(msg);
+  }
+  auto sharedEnc = cast<ttg::SharedEncodingTrait>(encoding);
+  if (descBlockType.getShape() == tensorType.getShape())
+    return sharedEnc;
+
+  return updateEncodingForShape(op, sharedEnc, tensorType);
+}
+
+int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape) {
+  assert(encoding);
+  auto mmaEncoding =
+      llvm::dyn_cast_or_null<ttg::NVMMASharedEncodingAttr>(encoding);
+
+  // The bounding box inner dimension must be less than or equal to the
+  // swizzle size.
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
+  // We clamp the block size and the codegen will emit multiple copy
+  // operations.
+  if (mmaEncoding) {
+    auto elemSize = mmaEncoding.getElementBitWidth() / 8;
+    return mmaEncoding.getSwizzlingByteWidth() / elemSize;
+  }
+
+  auto shapePerCTA = ttg::getShapePerCTA(encoding, shape);
+  return shapePerCTA.back();
+}
+
+std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty) {
+  auto encoding = ty.getBlockType().getEncoding();
+  auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
+  unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
+  if (!mmaEncoding) {
+    auto swizzledEnc = dyn_cast<ttg::SwizzledSharedEncodingAttr>(encoding);
+    if (!swizzledEnc || swizzledEnc.getVec() != 1 ||
+        swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) {
+      if (op)
+        op->emitError("Unhandled encoding type");
+      return std::nullopt;
+    }
+  }
+
+  bool fp4Padded = isFp4Padded(encoding);
+  assert(!fp4Padded || swizzleBytes == 128 &&
+                           "elem type .b4x16_p64 supports only 128B swizzling");
+
+  int32_t swizzleMode = 0;
+  if (swizzleBytes == 128) {
+    swizzleMode = 3;
+  } else if (swizzleBytes == 64) {
+    swizzleMode = 2;
+  } else if (swizzleBytes == 32) {
+    swizzleMode = 1;
+  }
+  return swizzleMode;
+}
+
+std::optional<int> getTMAElementType(Operation *op, TensorDescType ty) {
+  auto encoding = ty.getBlockType().getEncoding();
+  auto mmaEncoding = dyn_cast<ttg::NVMMASharedEncodingAttr>(encoding);
+  bool fp4Padded = isFp4Padded(encoding);
+
+  if (fp4Padded)
+    return 14; // .b4x16_p64
+
+  auto elemSize = ty.getBlockType().getElementTypeBitWidth() / 8;
+  switch (elemSize) {
+  case 1:
+    return 0;
+  case 2:
+    return 1;
+  case 4:
+    return 2;
+  default:
+    break;
+  }
+  if (op) {
+    op->emitError()
+        << "Tensor descriptor element type must have size 1, 2, or 4 but got "
+        << elemSize;
+  }
+  return std::nullopt;
+}
+
+} // namespace mlir::triton::nvidia_gpu
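The least obvious piece of this file is the rank-increasing branch of updateCTALayoutForShape, used for rank-reducing loads. Below is a plain-Python mirror of that branch, with the MLIR attributes replaced by lists; it is a sketch under those assumptions, not Triton code, and the example layout is made up:

def pad_cta_layout(ctas_per_cga, cta_split_num, cta_order, shape):
    # Left-pad each field with size-1 CTAs so the layout rank matches the
    # tensor rank; the leading tensor dims are asserted to be 1 in the C++.
    rank, d = len(shape), len(shape) - len(ctas_per_cga)
    assert d > 0 and all(s == 1 for s in shape[:d])
    return (
        [1] * d + list(ctas_per_cga),
        [1] * d + list(cta_split_num),
        [rank - i for i in range(d)] + list(cta_order),  # CTAOrder[i] = rank - i
    )

# A rank-2 CTA layout feeding a rank-reducing load of block shape [1, 64, 64]:
print(pad_cta_layout([2, 1], [2, 1], [1, 0], [1, 64, 64]))
# -> ([1, 2, 1], [1, 2, 1], [3, 1, 0])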

python/triton/tools/experimental_descriptor.py

Lines changed: 8 additions & 0 deletions
@@ -46,3 +46,11 @@ class TensorDescriptor:
     shape: List[int]
     strides: List[int]
     block_shape: List[int]
+
+    def from_tensor(tensor: Any, block_shape: List[int]):
+        return TensorDescriptor(
+            tensor,
+            tensor.shape,
+            tensor.stride(),
+            block_shape,
+        )
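A possible call site for the new helper, assuming a PyTorch CUDA tensor (any object exposing .shape and .stride() works the same way) and assuming the module path matches the file above; the sizes are made up:

import torch
from triton.tools.experimental_descriptor import TensorDescriptor

x = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
# Records the tensor together with its shape, its strides, and the block
# shape that each program will transfer.
desc = TensorDescriptor.from_tensor(x, [128, 64])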
