Commit 868d242

[BACKEND] Towards a generic tcgen05.cp lowering (triton-lang#8102)
This is the first PR towards a fully generic `tcgen05.cp` lowering. For now it still has limitations similar to the previous lowering, but it no longer implicitly assumes the layout of the shared memory and tensor memory. Instead, it checks that the given TMEM and shmem layouts are compatible with the instruction we are lowering to and, if so, computes the matrix descriptor and TMEM offsets manually.
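For orientation, the instruction shape is still effectively fixed: the lowering targets the 128x256b tcgen05.cp message, i.e. 128 rows by 256 bits per row. A minimal standalone sketch of that geometry follows; the helper name and the example tile size are illustrative, not part of this PR.

#include <cassert>
#include <cstdio>

// Illustrative only: a 128x256b tcgen05.cp message covers 128 rows and
// 256 / elemBitWidth elements per row.
struct MessageShape {
  int rows;
  int cols;
};

static MessageShape messageShapeFor(int elemBitWidth) {
  return {128, 256 / elemBitWidth};
}

int main() {
  // Example: for fp32 (the only element width the verifier accepts on the
  // non-scales path), the message is 128x8 elements, so a 128x128 tile
  // takes (128/128) * (128/8) = 16 tcgen05.cp messages.
  MessageShape s = messageShapeFor(/*elemBitWidth=*/32);
  assert(s.rows == 128 && s.cols == 8);
  std::printf("messages for a 128x128 fp32 tile: %d\n",
              (128 / s.rows) * (128 / s.cols));
}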
1 parent 3b00caa commit 868d242

File tree

4 files changed: +207 -47 lines changed


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 4 additions & 0 deletions
@@ -146,5 +146,9 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
 // the two can be done using transferWithinWarp, without involving LDS
 std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
 
+// Create the core layout (atom in the PTX manual) a given nvmma shared encoding
+LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
+                                       bool disableSwizzle);
+
 } // namespace mlir::triton::gpu
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 5 additions & 5 deletions
@@ -183,6 +183,8 @@ sharedToLinearLayoutAMDRotating(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape);
 }
 
+} // namespace
+
 // Returns the layout of a single core matrix which tiles the nvmma layout
 LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
                                        bool disableSwizzle) {
@@ -195,7 +197,7 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
   int maxPhase = shared.getMaxPhase();
 
   int tileRows = 8;
-  int tileCols = 8 * tileWidthBytes / elemBitWidth;
+  int tileCols = 8 * std::max(16, tileWidthBytes) / elemBitWidth;
   bool isFp4Padded = shared.getFp4Padded();
 
   std::vector<std::vector<int>> bases2D;
@@ -227,8 +229,6 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
   return LinearLayout({{S("offset"), bases2D}}, outDimNames);
 }
 
-} // namespace
-
 LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
                                        NVMMASharedEncodingAttr shared,
                                        bool disableSwizzle) {
@@ -1180,7 +1180,7 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
   assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);
 
   auto blockM = encoding.getBlockM();
-  auto blockN = encoding.getBlockN();
+  auto blockN = std::min<int32_t>(encoding.getBlockN(), shape[1]);
   assert(blockM == 64 || blockM == 128);
   LinearLayout tile;
   if (blockM == 64) {
@@ -1190,7 +1190,7 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
     if (shape[0] > blockM) {
       bases[kRow].push_back({64, 0});
     } else if (shape[1] > blockN) {
-      bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
+      bases[kRow].push_back({0, blockN});
     } else {
       // Empty. This is modelled as broadcasting, same as for TMA(fp4)
       bases[kRow].push_back({0, 0});
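A quick way to read the tileCols change above: the core matrix (the "atom" from the PTX manual) is 8 rows by 8 * max(16, swizzlingByteWidth) / elemBitWidth columns, so an unswizzled encoding is treated as if it had 16-byte core rows. A small standalone check of that formula (helper name is illustrative, not part of the PR):

#include <algorithm>
#include <cassert>

// Illustrative helper matching the updated tileCols expression above and the
// asserts on getCoreMatrixLinearLayout's output in TensorMemoryToLLVM.cpp.
static int coreMatrixCols(int swizzlingByteWidth, int elemBitWidth) {
  return 8 * std::max(16, swizzlingByteWidth) / elemBitWidth;
}

int main() {
  assert(coreMatrixCols(/*swizzling=*/0, /*bits=*/32) == 4);   // 16-byte rows
  assert(coreMatrixCols(/*swizzling=*/32, /*bits=*/32) == 8);
  assert(coreMatrixCols(/*swizzling=*/64, /*bits=*/16) == 32);
  assert(coreMatrixCols(/*swizzling=*/128, /*bits=*/32) == 32);
}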

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 19 additions & 5 deletions
@@ -689,14 +689,27 @@ LogicalResult TMEMCopyOp::verify() {
   }
   auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
   auto sharedEnc =
+      dyn_cast<triton::gpu::SharedEncodingTrait>(srcTy.getEncoding());
+  if (sharedEnc.getAlignment() < 16) {
+    return emitOpError("Source must have at least 16-byte alignment to be "
+                       "representable in a matrix descriptor.");
+  }
+
+  auto mod = getOperation()->getParentOfType<ModuleOp>();
+  unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
+  if (numCTAs != 1)
+    return emitOpError("NYI: Only one CTA is supported for now.");
+
+  auto nvmmaEnc =
       dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(srcTy.getEncoding());
-  if (!sharedEnc) {
+  if (!nvmmaEnc) {
     return emitOpError("Source must have nvmma layout.");
   }
-  if (sharedEnc.getTransposed() || sharedEnc.getFp4Padded())
-    return emitOpError("The source should not be transposed or passed");
+  // Fp4 we could lift if we needed
+  if (nvmmaEnc.getTransposed() || nvmmaEnc.getFp4Padded())
+    return emitOpError("The source should not be transposed or padded");
   if (isa<TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding())) {
-    if (sharedEnc.getSwizzlingByteWidth() != 0) {
+    if (nvmmaEnc.getSwizzlingByteWidth() != 0) {
       return emitOpError("The source should not be swizzled for now");
     }
     if (!triton::gpu::isInnermostContiguous(srcTy, 512)) {
@@ -715,9 +728,10 @@
     if (tmemEnc.getBlockM() != 128) {
       return emitOpError("Tmem layout ahouls have M=128.");
     }
-    if (sharedEnc.getSwizzlingByteWidth() == 0) {
+    if (nvmmaEnc.getSwizzlingByteWidth() == 0) {
       return emitOpError("Source layout should be swizzled.");
     }
+    // When we lift this, we should make sure we handle unpacked cleanly
     if (srcTy.getElementType().getIntOrFloatBitWidth() != 32) {
       return emitOpError("Source element type should be 32-bit.");
     }
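Taken together, the verifier above now enforces a small set of structural preconditions on tmem_copy sources. A condensed standalone restatement follows; the struct and function names are made up for this sketch, the real checks operate on the MLIR types shown in the diff, and the scales path additionally checks innermost contiguity (omitted here).

#include <optional>
#include <string>

// Illustrative summary of the TMEMCopyOp::verify conditions touched above.
struct CopySrcInfo {
  unsigned alignmentBytes;      // shared-memory alignment of the source
  unsigned numCTAs;             // CTAs per CGA in the module
  bool isNVMMALayout;
  bool transposed;
  bool fp4Padded;
  unsigned swizzlingByteWidth;
  unsigned elemBitWidth;
  bool dstIsScalesEncoding;
  unsigned dstBlockM;
};

static std::optional<std::string> firstVerifierError(const CopySrcInfo &s) {
  if (s.alignmentBytes < 16)
    return "source needs >= 16-byte alignment for the matrix descriptor";
  if (s.numCTAs != 1)
    return "NYI: only one CTA is supported for now";
  if (!s.isNVMMALayout)
    return "source must have nvmma layout";
  if (s.transposed || s.fp4Padded)
    return "source must not be transposed or padded";
  if (s.dstIsScalesEncoding) {
    if (s.swizzlingByteWidth != 0)
      return "source should not be swizzled for now";
  } else {
    if (s.dstBlockM != 128)
      return "tmem layout should have M=128";
    if (s.swizzlingByteWidth == 0)
      return "source layout should be swizzled";
    if (s.elemBitWidth != 32)
      return "source element type should be 32-bit";
  }
  return std::nullopt;
}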

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 179 additions & 37 deletions
@@ -992,53 +992,195 @@ static void copyScales(ConversionPatternRewriter &rewriter, Location loc,
   createCopy(repMorN, repK);
 }
 
+static std::optional<std::tuple<int32_t, LinearLayout, LinearLayout,
+                                SmallVector<int64_t>, int32_t>>
+getSwizzling(MemDescType shmemTy, MemDescType tmemTy) {
+  // cvt is a map from Tmem to Shmem
+  auto tmemLl = toLinearLayout(tmemTy);
+  auto shmemLl = toLinearLayout(shmemTy);
+  auto inDimNames = to_vector(tmemLl.getInDimNames());
+  auto *ctx = inDimNames[0].getContext();
+  assert(shmemLl.getInDimSize(str_attr("block")) == 1 && "NYI");
+  auto kOffset = str_attr("offset");
+  auto kRow = str_attr("row");
+  auto kCol = str_attr("col");
+  shmemLl = shmemLl.sublayout({kOffset}, to_vector(shmemLl.getOutDimNames()));
+  auto cvt = tmemLl.invertAndCompose(shmemLl);
+
+  int32_t bitwidth = tmemTy.getElementType().getIntOrFloatBitWidth();
+
+  // Check if the layout is large enough as to check SBO
+  // TODO Move to the verifier
+  if (shmemLl.getOutDimSizeLog2(str_attr("dim0")) < 4) {
+    return std::nullopt;
+  }
+  // TODO We may need to be careful here if we ever want to support fp4 padded
+  // layouts
+  if (!shmemLl.isInvertible()) {
+    return std::nullopt;
+  }
+
+  // This will be SBO for k-Contiguous layouts (like the ones used in
+  // tcgen05.cp)
+  auto sbo =
+      shmemLl.invert().getBasis(str_attr("dim0"), /*log2(8)=*/3, kOffset);
+
+  // TODO hardcoded to 128x256b for now
+  const SmallVector<int64_t> instrShape = {128, 256 / bitwidth};
+  // TODO Move to the verifier perhaps
+  // Can we move the tile?
+  // TODO We should be able to move any descriptor tile with 128x256b
+  // (or 128x128b for unswizzled when it just has one tile)
+  for (auto [inDimName, instrSize] : llvm::zip(inDimNames, instrShape)) {
+    if (cvt.getInDimSize(inDimName) < instrSize) {
+      return std::nullopt;
+    }
+  }
+
+  auto CTALayout = getCTALayout(shmemTy.getEncoding());
+
+  for (int swizzling : {0, 32, 64, 128}) {
+    // r = 0, 1, 2, 3
+    auto shmemEnc =
+        NVMMASharedEncodingAttr::get(ctx, swizzling, /*transposed=*/false,
+                                     bitwidth, /*fp4Padded=*/false, CTALayout);
+    auto shmemTile =
+        getCoreMatrixLinearLayout(shmemEnc, /*disableSwizzle=*/false);
+    // getCoreMatrixLinearLayout gives the k-contiguous tile
+    // shmemTile is a layout onto a matrix with shape
+    // If swizzling != 0: 8 x (8 * swizzling / bitwidth)
+    // If swizzling == 0: 8 x (8 * 16 / bitwidth)
+    assert(shmemTile.getOutDimSize(str_attr("dim0")) == 8);
+    assert(shmemTile.getOutDimSize(str_attr("dim1")) ==
+           8 * std::max(16, swizzling) / bitwidth);
+    // The shmemTile is mapped identically into the tmem, so we just need to
+    // rename the outDims in shmemTile from dim0, dim1 to row, col
+    auto cvtTileInverted =
+        LinearLayout(shmemTile.getBases(), {str_attr("row"), str_attr("col")});
+    // The tile should be invertible, so we consider it as a map from row, col
+    // to offset
+    // nb. Working with the map from row, col to offset is important to handle
+    // the tcgen05.cp instructions that do broadcasting
+    auto cvtTile = cvtTileInverted.invert();
+    // The sbo stride shall not touch the core tile
+    if (sbo < cvtTile.getOutDimSize(kOffset))
+      continue;
+
+    // As we are copying instrShape[0] columns in one go, to be able to
+    // represent this in the descriptor, we need to have a constant "stride"
+    // along the row dimension from row=8 until the last row.
+    auto bases = cvtTile.getBases();
+    for (int i = 1; i < instrShape[0] / 8; i *= 2) {
+      bases[kRow].push_back({sbo * i});
+    }
+    cvtTile = LinearLayout(bases, {{kOffset, sbo * (instrShape[0] / 8)}},
+                           /*requireSurjective=*/false);
+
+    auto quot = divideLeft(cvt, cvtTile);
+    if (quot.has_value()) {
+      if (auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(shmemEnc)) {
+        assert(nvmma.getSwizzlingByteWidth() == swizzling);
+      }
+      return std::make_tuple(swizzling, *quot, cvtTile, instrShape, sbo);
+    }
+  }
+  return std::nullopt;
+}
+
 static void copySharedToTmem(ConversionPatternRewriter &rewriter, Location loc,
                              const TypeConverter *typeConverter,
                              triton::nvidia_gpu::TMEMCopyOp op, Value src,
-                             Value dst, Value pred) {
+                             Value baseDst, Value pred) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
+  auto *ctx = op.getContext();
+  auto kOffset = str_attr("offset");
+  auto kRow = str_attr("row");
+  auto kCol = str_attr("col");
+
   MemDescType srcTy = op.getSrc().getType();
   MemDescType dstTy = op.getDst().getType();
+
+  auto sharedLl = toLinearLayout(srcTy);
+  sharedLl =
+      sharedLl.sublayout({kOffset}, to_vector(sharedLl.getOutDimNames()));
+  auto tmemLl = toLinearLayout(dstTy);
+  auto cvt = tmemLl.invertAndCompose(sharedLl);
+
+  auto bitwidth = srcTy.getElementType().getIntOrFloatBitWidth();
+  // Need to find the shmem tile that matches
+  auto maybeSwizzling = getSwizzling(srcTy, dstTy);
+  assert(maybeSwizzling.has_value());
+  auto [swizzling, quot, tile, tileShape, sbo] = std::move(*maybeSwizzling);
+
+  auto reps = zerosLike(tile) * quot;
+
+  // Get shmem ptr
+  // TODO We should not allow splitting along the swizzling pattern
   Type elemTy = typeConverter->convertType(srcTy.getElementType());
   auto smemObj =
       LLVM::getSharedMemoryObjectFromStruct(loc, src, elemTy, rewriter);
-  Value baseSrc = smemObj.getShmemAffineBase(loc, rewriter, srcTy);
+  Value baseSrcInt =
+      b.ptrtoint(i32_ty, smemObj.getShmemAffineBase(loc, rewriter, srcTy));
+  // We checked in the verifier that the alignment is at least 16
+  Value baseSrcIntShr4 = b.lshr(baseSrcInt, b.i32_val(4));
+
+  // Set common fields in the SMEMDescriptor
+  SMEMDescriptor desc;
+  desc.baseAddress = 0;
+  // For K-contig, leadDimension is assumed to be 1
+  desc.leadDimensionBaseOffset = 1;
+  // SBO is in elements and we have to pass it to bits and right shift by 4
+  desc.strideDimensionBaseOffset = ((sbo * (bitwidth / 8)) >> 4);
+  desc.matrixBaseOffset = 0;
+  switch (swizzling) {
+  case 0:
+    desc.swizzlingMode = 0;
+    break;
+  case 32:
+    desc.swizzlingMode = 3;
+    break;
+  case 64:
+    desc.swizzlingMode = 2;
+    break;
+  case 128:
+    desc.swizzlingMode = 1;
+    break;
+  default:
+    llvm::report_fatal_error("Unsupported swizzling size.");
+  }
 
-  Value baseDst = dst;
-  assert(srcTy.getElementType().getIntOrFloatBitWidth() == 32);
-
-  int blockN =
-      cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(dstTy.getEncoding())
-          .getBlockN();
-  // Currently, hardcoded to 128x256b message.
-  std::array<int, 2> instShape = {128, 8};
-  int repNPerBlock = blockN / instShape[1];
-  auto createCopy = [&](int repM, int repN) {
-    Value zero = b.i32_val(0);
-    SmallVector<int64_t> shape(op.getSrc().getType().getShape());
-    DotOpMmaV5SmemLoader smemLoader = DotOpMmaV5SmemLoader(
-        op.getSrc(), baseSrc, shape, op.getSrc().getType().getAllocShape(),
-        zero, 1, /*trans=*/false, {128, 8},
-        op.getSrc().getType().getElementType().getIntOrFloatBitWidth(),
-        rewriter, loc);
-    for (int m = 0; m < repM; m++) {
-      for (int n = 0; n < repN; n++) {
-        int colIndx =
-            (n % repNPerBlock) * instShape[1] +
-            m * repNPerBlock * instShape[1] +
-            (n / repNPerBlock) * (srcTy.getDimSize(0) / instShape[0]) * blockN;
-        auto colOffset = b.i32_val(colIndx);
-        auto tmemAddr = b.add(b.ptrtoint(i32_ty, baseDst), colOffset);
-        Value smemDesc = smemLoader.smemLoad(m, n, rewriter, loc);
-        createTcgen05Cp(rewriter, loc, tmemAddr, smemDesc, pred,
-                        /*scales=*/false);
-      }
+  // Make sure we don't have to iterate along the rows
+  assert(tile.getInDimSize(kRow) == cvt.getInDimSize(kRow) && "NYI");
+  assert(tileShape[1] <= tile.getInDimSize(kCol) && "NYI");
+  int elementBytes = bitwidth / 8;
+  for (int col = 0; col < reps.getInDimSize(kCol);
+       col += tile.getInDimSize(kCol)) {
+    // Compute base offset for the swizzling pattern
+    int32_t off = reps.apply({{kRow, 0}, {kCol, col}})[0].second;
+    desc.matrixBaseOffset = (off * elementBytes / 128) & 0x7;
+    uint64_t descBase = desc.descriptor;
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor
+    descBase |= (1ULL << 46);
+    Value descValBase = b.int_val(64, desc.descriptor);
+    for (int offset = 0; offset < tile.getInDimSize(kCol);
+         offset += tileShape[1]) {
+      // Compute total offset of the current message
+      int32_t totalOffElems =
+          cvt.apply({{kRow, 0}, {kCol, col + offset}})[0].second;
+      int32_t smemByteOffset = totalOffElems * elementBytes;
+      int32_t smemByteOffsetShr4 = smemByteOffset >> 4;
+      // We could fold this add into the descBase if we wanted to
+      Value baseAddr = b.add(baseSrcIntShr4, b.i32_val(smemByteOffsetShr4));
+      Value baseSrcDesc = b.zext(i64_ty, b.and_(baseAddr, b.i32_val(0x3FFF)));
+      // Add the base address to the descriptor
+      Value descVal = b.or_(descValBase, baseSrcDesc, /*disjoint=*/true);
+      auto tmemAddr =
+          b.or_(b.ptrtoint(i32_ty, baseDst), b.i32_val(col + offset),
+                /*disjoint=*/true);
+      createTcgen05Cp(rewriter, loc, tmemAddr, descVal, pred,
+                      /*scales=*/false);
     }
-  };
-
-  int repM = srcTy.getDimSize(0) / instShape[0];
-  int repN = srcTy.getDimSize(1) / instShape[1];
-  createCopy(repM, repN);
+  }
 }
 
 struct TensorMemoryCopyOpConversion
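The descriptor setup introduced above is worth restating in isolation: for the K-contiguous case the lead-dimension offset is fixed to 1, the stride-dimension offset is the SBO converted from elements to bytes and then shifted right by 4, and the swizzling byte width maps onto the descriptor's swizzling mode. A small standalone sketch follows; field names mirror the diff, but this is a plain struct and deliberately does not reproduce the real SMEMDescriptor bit packing.

#include <cstdint>
#include <stdexcept>

// Simplified stand-in for SMEMDescriptor: only the fields set above are
// modelled, and no 64-bit bit packing is performed here.
struct DescriptorFields {
  uint32_t baseAddress = 0;
  uint32_t leadDimensionBaseOffset = 0;
  uint32_t strideDimensionBaseOffset = 0;
  uint32_t matrixBaseOffset = 0;
  uint32_t swizzlingMode = 0;
};

static DescriptorFields commonDescriptorFields(int sboElems, int elemBitWidth,
                                               int swizzlingByteWidth) {
  DescriptorFields d;
  d.baseAddress = 0;              // patched per message with the smem address
  d.leadDimensionBaseOffset = 1;  // K-contiguous assumption
  d.strideDimensionBaseOffset = (sboElems * (elemBitWidth / 8)) >> 4;
  switch (swizzlingByteWidth) {   // same mapping as the switch in the diff
  case 0:
    d.swizzlingMode = 0;
    break;
  case 32:
    d.swizzlingMode = 3;
    break;
  case 64:
    d.swizzlingMode = 2;
    break;
  case 128:
    d.swizzlingMode = 1;
    break;
  default:
    throw std::runtime_error("unsupported swizzling size");
  }
  return d;
}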
@@ -1048,7 +1190,7 @@ struct TensorMemoryCopyOpConversion
   LogicalResult
   matchAndRewrite(triton::nvidia_gpu::TMEMCopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
+    assert(lookupNumCTAs(rewriter) == 1 && "NYI");
     Location loc = op->getLoc();
     Value pred = LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter);
     if (isa<TensorMemoryScalesEncodingAttr>(
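One last detail from the emission loop earlier in this file: both the shared-memory base inside the descriptor and the TMEM address are combined with their per-message offsets using an or marked disjoint instead of an add. That is only valid when the offset's set bits do not overlap the base's, in which case or and add coincide; a tiny standalone check of the identity with illustrative values:

#include <cassert>
#include <cstdint>

int main() {
  // If base and offset share no set bits, (base | offset) == base + offset,
  // which is what the /*disjoint=*/true or_ calls above rely on.
  uint32_t base = 0x00010000;   // illustrative column-aligned base address
  uint32_t offset = 0x00000008; // illustrative per-message column offset
  assert((base & offset) == 0);
  assert((base | offset) == base + offset);
  return 0;
}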
