Commit 6e4647e

[BACKEND] Lower tcgen05.cp via the generic matrix descriptor lowering (#8338)
This also fixes a bug found by @masahi, but this uncovered a PTX bug, so we stay as we were on that front lol
1 parent: aafec41

5 files changed: +154 −290 lines

python/test/gluon/test_core.py

Lines changed: 0 additions & 2 deletions
@@ -692,8 +692,6 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
     tmem_alias: ttgl.constexpr = TensorMemoryLayout((num_rows, num_cols), col_stride=1)
     tmem = tmem._reinterpret(ttgl.int8, (num_rows, num_cols), tmem_alias)
     value = tmem.load(blocked)
-    ttgl.static_print(ttgl.to_linear_layout(blocked, (smem_h, smem_w)))
-    ttgl.static_print(ttgl.to_linear_layout(blocked, (num_rows, num_cols)))
     ttgl.store(ttgl.set_auto_layout(out_ptrs, blocked), value)
 
     torch.manual_seed(0)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h

Lines changed: 88 additions & 99 deletions
@@ -28,7 +28,6 @@ struct MMASMEMDescriptor {
   SMEMDescriptor descriptor;
   int32_t swizzlingByteWidth;
   int32_t bitwidth;
-  bool twoCTAs;
   bool transposed;
   bool fp4Padded;
 };
@@ -53,77 +52,67 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
 
   DotOpMmaSmemLoader(MMASMEMDescriptor desc, Value baseb128, LinearLayout llInv,
                      ArrayRef<unsigned> instrShape)
-      : desc(desc), baseb128(baseb128), llInv(std::move(llInv)),
+      : desc(desc), baseb128(baseb128), ll(std::move(llInv)),
        instrShape(instrShape) {}
 
   static DotOpMmaSmemLoader
-  build(Location loc, RewriterBase &rewriter, triton::gpu::MemDescType tensor,
+  build(Location loc, RewriterBase &rewriter, gpu::MemDescType memTy,
        Value smemBase, ArrayRef<unsigned> instrShape, int mmaVersion,
+        bool isFp4 = false,
        std::optional<RankedTensorType> mmaTy = std::nullopt,
        std::optional<unsigned> MNdim = std::nullopt) {
-    auto ctx = tensor.getContext();
+    auto ctx = rewriter.getContext();
+    auto kOffset = str_attr("offset");
+    // The handling of subviews is not as fine as it could be
+    // We could compose with the identity of the memTy.getShape()
+    // (at the moment llInv will be of allocShape), but then
+    // we would need to handle the getReps part more carefuly
+    // This way we could support more subviews that we don't
+    // We can implement this generalisation in the future if needed
+    auto llInv = toLinearLayout(memTy).pseudoinvert();
+    auto bitwidth = memTy.getElementType().getIntOrFloatBitWidth();
+    if (isFp4) {
+      // hacky but well
+      auto dims = to_vector(llInv.getInDimNames());
+      auto trans = llInv.getBasis(dims[0], 0, kOffset) == 1;
+      llInv = LinearLayout::identity1D(2, dims[trans ? 0 : 1], kOffset) * llInv;
+      bitwidth /= 2;
+      // The instr_shape comes in number of elements already
+    }
+    return build(loc, rewriter, llInv, bitwidth, smemBase, instrShape,
+                 mmaVersion, mmaTy, MNdim);
+  }
+
+  static DotOpMmaSmemLoader
+  build(Location loc, RewriterBase &rewriter, const LinearLayout &ll,
+        int bitwidth, Value smemBase, ArrayRef<unsigned> instrShapeArray,
+        int mmaVersion, std::optional<RankedTensorType> mmaTy = std::nullopt,
+        std::optional<unsigned> MNdim = std::nullopt) {
+    // ll is a map from two dimensions (dim0, dim1) or (row, col) into offsets
+    // and blocks
+    auto ctx = rewriter.getContext();
+    auto kOffset = str_attr("offset");
+    auto kBlock = str_attr("block");
+    assert(ll.getNumOutDims() == 2);
+    assert(ll.hasOutDim(kOffset) && ll.hasOutDim(kBlock));
+
     assert(mmaVersion == 3 || mmaVersion == 5);
     // Just needed for MMAv3
     assert(mmaTy.has_value() == (mmaVersion == 3));
     assert(MNdim.has_value() == (mmaVersion == 3));
     if (mmaVersion == 3) {
       assert(MNdim.value() < 2);
     }
+    auto instrShape = to_vector(instrShapeArray);
     assert(instrShape.size() == 2);
     auto b = TritonLLVMOpBuilder(loc, rewriter);
-    // TODO Assert that calling getShmemAffineBase is valid!
-
-    // Due to the alignment, we can transform ((base + offset) & 0x3FFFF) >> 4
-    // into ((base >> 4) & 0x3FFF + (offset >> 4) where offset is in the inner
-    // loop and ((base >> 4) & 0x3FFF) can be computed once.
-    assert(cast<triton::gpu::SharedEncodingTrait>(tensor.getEncoding())
-               .getAlignment() >= 16);
 
+    // Due to having a 16B alignment, we can compute the offsets in 128b
+    // elements
+    // TODO We should assert in the verifier that the alignment is at least 16B
     smemBase = b.ptrtoint(i32_ty, smemBase);
     Value baseSrcb128 = b.lshr(smemBase, b.i32_val(4));
-    int bitwidth = tensor.getElementType().getIntOrFloatBitWidth();
 
-    auto ll = toLinearLayout(tensor);
-    auto kOffset = str_attr("offset");
-    assert(ll.getNumOutDims() == 2);
-    auto dims = to_vector(ll.getOutDimNames());
-    // The linear layout for fp4 represents the matrix as i8s
-    // For it to play ball with instrShape, which is in terms of the original
-    // tensor, we need to represent it as i4s
-    // Interestingly enough, we support i8 x i8 matmul by the looks of it
-    auto isFp4 =
-        tensor.getElementType() == IntegerType::get(ctx, 8) && mmaVersion == 5;
-    auto shape = to_vector(tensor.getShape());
-    if (isFp4) {
-      // hacky but well
-      auto trans = ll.getBasis(kOffset, 0)[0] != 0;
-      ll = LinearLayout::identity1D(2, kOffset, dims[trans ? 0 : 1]) * ll;
-      shape[trans ? 0 : 1] *= 2;
-      bitwidth /= 2;
-      // The instr_shape comes in number of elements already
-    }
-
-    for (auto [dim, instrSize] : llvm::zip(ll.getOutDimNames(), instrShape)) {
-      assert(instrSize <= ll.getOutDimSize(dim) &&
-             "Instr shape is too large for the layout");
-    }
-
-    // TODO Add this to the verifier
-    // We represent fp4 padded tensors as i8s
-    auto desc = getDescriptor(ll, instrShape, bitwidth, mmaVersion);
-
-    // In case it was a subview, we resize it by composing it with the identity
-    // of shape getShape (rather than shape getAllocShape, as toLinearLayout
-    // returns)
-    // Also in the case of mutlicta where the different CTAs have broadcasting
-    // (so no 2-CTA MMA) we effectively need to pseudoinvert. This also achieves
-    // that
-    auto outDims = to_vector(ll.getOutDimNames());
-    auto identity = LinearLayout::identity1D(shape[0], outDims[0], outDims[0]) *
-                    LinearLayout::identity1D(shape[1], outDims[1], outDims[1]);
-    auto llInv = identity.invertAndCompose(ll);
-
-    auto blockInstrShape = to_vector(instrShape);
     if (mmaVersion == 3) {
       auto mndim = MNdim.value();
       auto mmaLl = gpu::toLinearLayout(mmaTy.value());
@@ -133,15 +122,15 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
       auto mmaWarps = mmaLl.sublayout({kWarp}, {outDims[mndim]}) *
                       LinearLayout::identity1D(1, kWarp, outDims[1 - mndim]);
       // Map from warps to offsets in bitwidth elements
-      auto warpToOffset = mmaWarps.compose(llInv);
+      auto warpToOffset = mmaWarps.compose(ll);
       // Map from warps to offsets in 128b elements
       auto maybeWarpToOffsetb128 =
          divideLeft(warpToOffset,
                     LinearLayout::zeros1D(1, kWarp, kOffset, 128 / bitwidth));
       assert(maybeWarpToOffsetb128.has_value());
       // zero out the first two warp bases to have a warpgroup to offset map
-      assert(maybeWarpToOffsetb128->getNumOutDims() == 2);
       auto bases = maybeWarpToOffsetb128->getBases();
+      assert(maybeWarpToOffsetb128->getNumOutDims() == 2);
       bases[kWarp][0] = {0, 0};
       bases[kWarp][1] = {0, 0};
       auto warpGroupToOffsetb128 = LinearLayout(
@@ -152,34 +141,40 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
                               {{kWarp, warpId}})[0]
                              .second;
       baseSrcb128 = b.add(baseSrcb128, warpStrideb128);
-      // Increase the instruction shape to describe the size at a warp level
-      // A bit hacky but well
+      // Increase the instruction shape to describe the size at a block level
+      // as the input just describes it at a warp level
       int logwgAlongMN = 0;
       for (int i = 0; i < warpGroupToOffsetb128.getInDimSizeLog2(kWarp); i++) {
        if (warpGroupToOffsetb128.getBasis(kWarp, i, kOffset) != 0) {
          logwgAlongMN++;
        }
      }
-      blockInstrShape[mndim] *= (1 << logwgAlongMN);
+      instrShape[mndim] *= (1 << logwgAlongMN);
     }
 
+    for (auto [dim, instrSize] : llvm::zip(ll.getInDimNames(), instrShape)) {
+      assert(instrSize <= ll.getInDimSize(dim) &&
+             "Instruction shape is too large for the layout");
+    }
+
+    auto desc = getDescriptor(ll, instrShape, bitwidth, mmaVersion);
+
     Value baseb128 = b.zext(i64_ty, b.and_(baseSrcb128, b.i32_val(0x3FFF)));
-    return DotOpMmaSmemLoader(desc, baseb128, llInv, blockInstrShape);
+    return {desc, baseb128, ll, instrShape};
   }
 
   Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter,
                 Location loc) const {
     auto *ctx = loc.getContext();
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
-    auto dims = to_vector(llInv.getInDimNames());
-    assert((a + 1) * instrShape[0] <= llInv.getInDimSize(dims[0]));
-    assert((b + 1) * instrShape[1] <= llInv.getInDimSize(dims[1]));
-    assert(to_vector(llInv.getOutDimNames()) ==
+    auto dims = to_vector(ll.getInDimNames());
+    assert((a + 1) * instrShape[0] <= ll.getInDimSize(dims[0]));
+    assert((b + 1) * instrShape[1] <= ll.getInDimSize(dims[1]));
+    assert(to_vector(ll.getOutDimNames()) ==
           llvm::to_vector(
              ArrayRef<StringAttr>{str_attr("offset"), str_attr("block")}));
-    int32_t totalOffElems = llInv
-                                .apply({{dims[0], a * instrShape[0]},
-                                        {dims[1], b * instrShape[1]}})[0]
+    int32_t totalOffElems = ll.apply({{dims[0], a * instrShape[0]},
                                      {dims[1], b * instrShape[1]}})[0]
                                .second;
     int32_t smemByteOffsetb8 = totalOffElems * desc.bitwidth / 8;
     auto currDesc = desc.descriptor;
@@ -198,37 +193,22 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     return {smemLoad(a, b, rewriter, loc), std::nullopt};
   }
 
+  MMASMEMDescriptor &getDescriptor() { return desc; }
+
 private:
   MMASMEMDescriptor desc;
   Value baseb128;
-  LinearLayout llInv;
+  LinearLayout ll;
   SmallVector<unsigned> instrShape;
 
   static MMASMEMDescriptor getDescriptor(const LinearLayout &ll,
                                         ArrayRef<unsigned> instrShape,
                                         int bitwidth, int mmaVersion) {
     // ll is a map from allocShape into offsets and blocks
-    auto inv = ll.pseudoinvert();
-    auto dims = to_vector(inv.getInDimNames());
+    auto dims = to_vector(ll.getInDimNames());
     auto ctx = dims[0].getContext();
     auto kOffset = str_attr("offset");
 
-    // Detect tcgen05.mma.cta_group::2 as having two CTAs that are not
-    // broadcasting
-    auto kBlock = str_attr("block");
-    auto twoCTAs = ll.getInDimSize(kBlock) > 1 &&
-                   ll.getBasis(kBlock, 0) != ArrayRef<int32_t>({0, 0});
-    SmallVector<unsigned> instrShapePerCTA = to_vector(instrShape);
-    if (twoCTAs) {
-      // In 2CTA mode we split the tensor into two CTAs
-      assert(ll.getInDimSize(kBlock) == 2);
-      if (ll.getBasis(kBlock, 0, dims[0]) != 0) {
-        instrShapePerCTA[0] /= 2;
-      } else {
-        instrShapePerCTA[1] /= 2;
-      }
-    }
-
     // Any CTALayout, it's not really used within getCoreMatrixLinearLayout
     auto CTALayout = triton::gpu::CTALayoutAttr::getDefault(ctx, 2);
 
@@ -242,11 +222,19 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
                                           CTALayout);
     auto shmemTile =
        getCoreMatrixLinearLayout(shmemEnc, /*disableSwizzle=*/false);
-    // We unpack the bitwidth == 8 tile
+    // Rename out dims to match the original layout (in case the dims were
+    // (row, col))
+    auto outDims = to_vector(shmemTile.getOutDims());
+    outDims[0].first = dims[0];
+    outDims[1].first = dims[1];
+    shmemTile = LinearLayout(shmemTile.getBases(), outDims,
+                             /*requireSurjective=*/false);
+    // unpack the fp4 layout
     if (bitwidth == 4) {
       shmemTile =
          LinearLayout::identity1D(2, kOffset, dims[1]) * shmemTile;
     }
+
     // getCoreMatrixLinearLayout gives the k-contiguous tile
     // shmemTile is a layout onto a matrix with shape
     // If swizzling != 0: 8 x (8 * swizzling / bitwidth)
@@ -266,10 +254,12 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     // The PTX docs are wrong in a number of ways:
     // 1) LBO can be specified for !transposed && swizzled != 0
     //    PTX says it's assumed to be 1, but we can in fact use it
-    // 2) LBO / SBO are swapped also for !transposed && swizzled != 0
+    // 2) LBO / SBO are swapped also for !transposed && swizzled == 0
     //    PTX just reports this for the transposed case
-    // Luckily enough the generic logic is much simpler than what's
-    // described in the docs
+    // EVEN MORE the computation we do here is conceptually correct
+    // and it agrees with the tensor descriptors for wgmma or
+    // tcgen05.mma but not for tcgen05.cp! tcgen05.cp follows the PTX
+    // docs!
     int lbo = 0, sbo = 0;
     int leadingDim = transposed ? 0 : 1;
     int stridedDim = transposed ? 1 : 0;
@@ -279,13 +269,13 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
      std::swap(leadingDim, stridedDim);
     }
     auto log2RowsTile = shmemTileInv.getInDimSizeLog2(dims[leadingDim]);
-    if (inv.getInDimSizeLog2(dims[leadingDim]) > log2RowsTile) {
-      lbo = inv.getBasis(dims[leadingDim], log2RowsTile, kOffset);
+    if (llvm::Log2_32(instrShape[leadingDim]) > log2RowsTile) {
+      lbo = ll.getBasis(dims[leadingDim], log2RowsTile, kOffset);
     }
 
     auto log2ColsTile = shmemTileInv.getInDimSizeLog2(dims[stridedDim]);
-    if (inv.getInDimSizeLog2(dims[stridedDim]) > log2ColsTile) {
-      sbo = inv.getBasis(dims[stridedDim], log2ColsTile, kOffset);
+    if (llvm::Log2_32(instrShape[stridedDim]) > log2ColsTile) {
+      sbo = ll.getBasis(dims[stridedDim], log2ColsTile, kOffset);
     }
 
     // Pad the tile up to the full instruction shape with the relevant
@@ -294,9 +284,9 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     for (int d : {0, 1}) {
       // 'tile' with the atom tile according to the lbo/sbo rules
       for (int i = 1;
-           i < instrShapePerCTA[d] / shmemTileInv.getInDimSize(dims[d]);
+           i < instrShape[d] / shmemTileInv.getInDimSize(dims[d]);
           i *= 2) {
-        auto stride = inv.getBasis(
+        auto stride = ll.getBasis(
            dims[d], shmemTileInv.getInDimSizeLog2(dims[d]), kOffset);
        bases[dims[d]].push_back({stride * i});
      }
@@ -316,11 +306,11 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     shmemTileInv *=
        LinearLayout::identity1D(1, dims[0], str_attr("block"));
 
-    auto quot = getReps(inv, shmemTileInv);
-    if (quot.has_value()) {
+    auto reps = getReps(ll, shmemTileInv);
+    if (reps.has_value()) {
      SMEMDescriptor desc;
      desc.descriptor = mmaVersion == 5 ? 1ULL << 46 : 0ULL;
-      // The lbo / sbo is defined wrt. the 128 tile
+      // The lbo / sbo is defined wrt. the 128b elements
      desc.leadDimensionBaseOffset = (lbo * bitwidth / 8) >> 4;
      desc.strideDimensionBaseOffset = (sbo * bitwidth / 8) >> 4;
      switch (swizzling) {
@@ -342,7 +332,6 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     return {.descriptor = desc,
            .swizzlingByteWidth = swizzling,
            .bitwidth = bitwidth,
-            .twoCTAs = twoCTAs,
            .transposed = transposed,
            .fp4Padded = fp4Padded};
   }
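
The removed comment above spells out the alignment trick the new code still relies on: with a 16-byte-aligned base and 16-byte-aligned tile offsets, the descriptor base ((base + offset) & 0x3FFFF) >> 4 can be split into a hoisted ((base >> 4) & 0x3FFF) plus a per-tile (offset >> 4). Below is a minimal standalone sketch, not part of the commit, that checks this identity; the loop bounds are arbitrary, and the split sum is reduced to the same 14-bit width the descriptor field stores so the comparison is exact.

#include <cassert>
#include <cstdint>

int main() {
  // base: 16B-aligned shared-memory address; offset: 16B-aligned tile offset.
  for (uint32_t base = 0; base < (1u << 20); base += 16) {
    for (uint32_t offset = 0; offset < (1u << 18); offset += 16 * 257) {
      // Single-shot form from the removed comment.
      uint32_t fused = ((base + offset) & 0x3FFFFu) >> 4;
      // Hoisted form: (base >> 4) & 0x3FFF is computed once per tensor,
      // offset >> 4 varies per tile; the sum wraps in the 14-bit field.
      uint32_t split = (((base >> 4) & 0x3FFFu) + (offset >> 4)) & 0x3FFFu;
      assert(fused == split);
    }
  }
  return 0;
}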

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 17 additions & 7 deletions
@@ -453,14 +453,24 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
    aLoader = std::make_unique<DotOpMmaV5TmemLoader>(a, baseA, aOperandShape,
                                                     interleaved, transA);
  } else {
-    auto allocShapeA = getAllocShape(aTensorTy, 1);
+    auto isFp4a = op.numBitsPerElementA == 4;
    aLoader = std::make_unique<DotOpMmaSmemLoader>(DotOpMmaSmemLoader::build(
-        loc, rewriter, aTensorTy, baseA, aOperandShape, 5));
+        loc, rewriter, aTensorTy, baseA, aOperandShape, 5, isFp4a));
  }
 
+  auto isFp4b = op.numBitsPerElementB == 4;
  auto allocShapeB = getAllocShape(bTensorTy, 0);
+  // [Instr shape twoCTAs]
+  // This division by 2 in 2CTA mode a bit subtle:
+  // The issue here is that in 2CTA you multiply in one instruction a tensor
+  // of shape MNK = 256, K, N, and you put it into TMEM of shape 128, K, N*2.
+  // So to compute the shapePerCTA, on the lhs we can look at the TMEM shape,
+  // but to compute the shapePerCTA on the rhs, we need to divide by 2.
+  // Something similar happens when we multiply by 2 the mmaSizeM when creating
+  // It's a massive code smell tho
  DotOpMmaSmemLoader bLoader = DotOpMmaSmemLoader::build(
-      loc, rewriter, bTensorTy, baseB, {mmaSizeK, mmaSizeN}, 5);
+      loc, rewriter, bTensorTy, baseB, {mmaSizeK, mmaSizeN / (twoCTAs ? 2 : 1)},
+      5, isFp4b);
 
  DotConversion::InstDesc desc{mmaSizeM, mmaSizeN, {numRepM, numRepN, numRepK},
                               transA, transB, interleaved,
@@ -522,6 +532,8 @@ void convertDot(const LLVMTypeConverter &typeConverter,
                  Value pred, Value useInitAcc,
                  const DotConversion::InstDesc &desc, int m, int n,
                  int k) {
+    // To understand this multiplication by 2, see the comment
+    // [Instr shape twoCTAs]
    Value instDescriptor = createInstDescriptor(
        rewriter, op, twoCTAs ? desc.mmaSizeM * 2 : desc.mmaSizeM,
        desc.mmaSizeN, desc.transA, desc.transB);
@@ -594,10 +606,8 @@ void convertScaledDot(const LLVMTypeConverter &typeConverter,
    dot.shapeB[0] *= 2;
  }
 
-  dot.numBitsPerElementA = opKindIsMXFP4 ? getFormatBitSize(op.getAType())
-                                         : aTensorTy.getElementTypeBitWidth();
-  dot.numBitsPerElementB = opKindIsMXFP4 ? getFormatBitSize(op.getBType())
-                                         : bTensorTy.getElementTypeBitWidth();
+  dot.numBitsPerElementA = getFormatBitSize(op.getAType());
+  dot.numBitsPerElementB = getFormatBitSize(op.getBType());
 
  TritonLLVMOpBuilder tb(loc, rewriter);
  Value baseD = tb.ptrtoint(i32_ty, adaptor.getD());
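
The [Instr shape twoCTAs] comment above describes the per-CTA bookkeeping this hunk implements: a cta_group::2 MMA computes an M = 256 tile while each CTA's TMEM holds only 128 rows, so the B operand's per-CTA instruction shape halves N, and the instruction descriptor is later created with 2 * mmaSizeM. The sketch below, not part of the commit and using a hypothetical struct and helper name, just restates that arithmetic.

#include <cassert>
#include <cstdio>

struct InstrShapes {
  unsigned bShape[2];   // per-CTA shape passed when building the B smem loader
  unsigned descriptorM; // M encoded into the tcgen05 instruction descriptor
};

InstrShapes computeShapes(unsigned mmaSizeM, unsigned mmaSizeN,
                          unsigned mmaSizeK, bool twoCTAs) {
  InstrShapes s;
  s.bShape[0] = mmaSizeK;
  // rhs shapePerCTA divides N by 2 in 2CTA mode, as in the bLoader build call.
  s.bShape[1] = twoCTAs ? mmaSizeN / 2 : mmaSizeN;
  // The descriptor advertises the full M of the instruction (2 * per-CTA M).
  s.descriptorM = twoCTAs ? mmaSizeM * 2 : mmaSizeM;
  return s;
}

int main() {
  auto s = computeShapes(/*M=*/128, /*N=*/256, /*K=*/32, /*twoCTAs=*/true);
  assert(s.bShape[1] == 128 && s.descriptorM == 256);
  std::printf("B instr shape = {%u, %u}, descriptor M = %u\n",
              s.bShape[0], s.bShape[1], s.descriptorM);
  return 0;
}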
