
Commit b188033

[BACKEND] Move cp.async to better lowering sequence (#7304)
In this PR we rehash the lowering of cp.async to reuse the previous optimisations. Net positive on internal benchmarks.
1 parent 6fcbac9 commit b188033

5 files changed (+149, -88 lines)

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 12 additions & 0 deletions
@@ -566,6 +566,18 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 ConversionPatternRewriter &rewriter,
                 const TargetInfoBase &targetInfo);

+// Lower an ld/st-like operation given a layout and a callback that creates the
+// PTX instruction. Lowers to st when valsArray is empty, and to ld when it is
+// not, and returns the output values.
+SmallVector<Value> lowerLdSt(
+    Location loc, MLIRContext *ctx, LinearLayout cvt,
+    ArrayRef<Value> valsArray, // Input for store, output for load
+    Type llvmElemTy, Value smemBase, ConversionPatternRewriter &rewriter,
+    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
+    std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
+                                     ArrayRef<Value>, Value, int, VectorType)>
+        lowerInst);
+
 // Lower local_load/local_store via ld.shared/st.shared
 SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
                                   // Map from registers to offset
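The header only shows the signature, so here is a rough usage sketch of the new hook (a hypothetical caller, not part of this commit): lowerLdSt walks the layout, computes one shared-memory address per vector, and calls lowerInst once per vector; the callback returns the loaded elements for a load-like instruction and an empty vector for a store-like one.

auto lowerInst = [&](ConversionPatternRewriter &rewriter, Location loc,
                     ArrayRef<Value> vals, Value shmemAddr, int startIdx,
                     VectorType vecTy) -> SmallVector<Value> {
  // `vals` is the flat value list handed to lowerLdSt; this call covers
  // vals[startIdx .. startIdx + vecTy.getNumElements()).
  // Emit one ld/st/cp.async-like instruction targeting shmemAddr here.
  return {}; // store-like: nothing to return
};
SmallVector<Value> outVals =
    lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase, rewriter,
              targetInfo, /*maybeMaxVecElems=*/std::nullopt, lowerInst);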

include/triton/Tools/LinearLayout.h

Lines changed: 5 additions & 0 deletions
@@ -838,6 +838,11 @@ class ColumnAction {
   // Inverse of the action
   ColumnAction inverse() const;

+  static ColumnAction identity(StringAttr inDim, size_t inSizeLog2) {
+    return ColumnAction(llvm::to_vector(llvm::seq<size_t>(inSizeLog2)), inDim,
+                        inSizeLog2);
+  }
+
   // Returns true if the action is the identity
   bool isIdentity() const { return m_isIdentity; }

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 50 additions & 16 deletions
@@ -12,6 +12,8 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/MathExtras.h"

+#include <functional>
+
 #if defined(_MSC_VER) && !defined(__clang__)
 // from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0
 #include <intrin.h>
@@ -513,20 +515,28 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
 }

 std::pair<int, ColumnAction>
-largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth) {
+largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
+                     std::optional<int> maybeMaxVecElems = std::nullopt) {
   // Find the largest vectorisation we can use:
   StringAttr kReg = str_attr("register");
   StringAttr kOffset = str_attr("offset");
   LinearLayout quot;
   LinearLayout tile;
   ColumnAction permutation;
-  for (int v = 128 / bitwidth; v >= 1; v /= 2) {
+  // If there are restrictions on the vectorisation, we don't allow
+  // permutations.
+  auto allowPerm = !maybeMaxVecElems.has_value();
+  auto maxVecElems = maybeMaxVecElems.value_or(128 / bitwidth);
+  for (int v = maxVecElems; v >= 1; v /= 2) {
     tile = LinearLayout::identity1D(v, kReg, kOffset);
     auto maybePerm = regPermForDivide(cvt, tile, /*left=*/true);
     if (!maybePerm) {
       continue;
     }
     permutation = *maybePerm;
+    if (!allowPerm && !permutation.isIdentity()) {
+      continue;
+    }
     auto newCvt = permutation.apply(cvt);
     auto maybeQuot = divideLeft(newCvt, tile);
     if (!maybeQuot) {
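For context on the new maybeMaxVecElems parameter: the loop above searches vector widths from the widest allowed down to 1, and when a caller imposes a cap it also refuses register permutations. A minimal, self-contained sketch of that search order (plain C++, no MLIR; the layout-divisibility checks are omitted):

#include <cstdio>
#include <optional>

// Print the candidate vector widths, widest first, for a given element bitwidth
// and an optional caller-imposed cap (mirrors the `for (int v = maxVecElems; ...)`
// loop above; whether a width is actually usable depends on the layout).
static void printSearchOrder(int bitwidth, std::optional<int> maybeMaxVecElems) {
  int maxVecElems = maybeMaxVecElems.value_or(128 / bitwidth);
  std::printf("bitwidth=%d:", bitwidth);
  for (int v = maxVecElems; v >= 1; v /= 2)
    std::printf(" %d", v);
  std::printf("\n");
}

int main() {
  printSearchOrder(16, std::nullopt); // 8 4 2 1 (128-bit accesses on f16)
  printSearchOrder(32, std::nullopt); // 4 2 1
  printSearchOrder(32, 4);            // 4 2 1, and permutations are disallowed
}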
@@ -544,6 +554,39 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 Type llvmElemTy, Value smemBase,
                 ConversionPatternRewriter &rewriter,
                 const TargetInfoBase &targetInfo) {
+
+  bool isStore = !valsArray.empty();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+
+  auto emitCpAsync = [&](ConversionPatternRewriter &rewriter, Location loc,
+                         ArrayRef<Value> vals, Value shmemAddr, int idx,
+                         VectorType vecTy) -> SmallVector<Value> {
+    auto length = vecTy.getNumElements();
+    if (isStore) {
+      Value valsVec =
+          packLLVector(loc, ArrayRef<Value>(vals).slice(idx, length), rewriter);
+      targetInfo.storeDShared(rewriter, loc, shmemAddr, std::nullopt, valsVec,
+                              /*pred=*/b.true_val());
+      return {};
+    } else {
+      assert(vals.empty());
+      Value valsVec = targetInfo.loadDShared(
+          rewriter, loc, shmemAddr, std::nullopt, vecTy, /*pred=*/b.true_val());
+      return unpackLLVector(loc, valsVec, rewriter);
+    }
+  };
+  return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase, rewriter,
+                   targetInfo, {}, emitCpAsync);
+}
+
+SmallVector<Value> lowerLdSt(
+    Location loc, MLIRContext *ctx, LinearLayout cvt,
+    ArrayRef<Value> valsArray, // Input for store, output for load
+    Type llvmElemTy, Value smemBase, ConversionPatternRewriter &rewriter,
+    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
+    std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
+                                     ArrayRef<Value>, Value, int, VectorType)>
+        lowerInst) {
   auto vals = to_vector(valsArray);
   bool isStore = !vals.empty();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -554,7 +597,8 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
   auto kOffset = str_attr("offset");
   auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();

-  auto [elemsPerVec, permutation] = largestVectorisation(ctx, cvt, bitwidth);
+  auto [elemsPerVec, permutation] =
+      largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems);

   cvt = permutation.apply(cvt);
   if (isStore) {
@@ -586,6 +630,7 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
           {{kReg, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}})[0]
           .second;
   SmallVector<Value> outVals;
+  auto vecTy = vec_ty(llvmElemTy, elemsPerVec);
   for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) {
     auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second;
     auto regIdxI8 = regIdx * (bitwidth / 8);
@@ -598,19 +643,8 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
       Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8));
       auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset,
                            LLVM::GEPNoWrapFlags::inbounds);
-      // Lezcano: Do we want to use getFreeVariableMasks for pred or nah?
-      if (isStore) {
-        Value valsVec = packLLVector(
-            loc, ArrayRef<Value>(vals).slice(i + j, elemsPerVec), rewriter);
-        targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
-                                /*pred=*/b.true_val());
-      } else {
-        Value valsVec =
-            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
-                                   vec_ty(llvmElemTy, elemsPerVec),
-                                   /*pred=*/b.true_val());
-        llvm::append_range(outVals, unpackLLVector(loc, valsVec, rewriter));
-      }
+      llvm::append_range(outVals,
+                         lowerInst(rewriter, loc, vals, vecAddr, i + j, vecTy));
     }
   }

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 5 deletions
@@ -656,7 +656,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32

     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
-    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
     // CHECK: nvvm.cp.async.commit.group
     %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr<f32>, #AL> -> !ttg.memdesc<16x64xf32, #A, #smem, mutable>
     ttg.async_commit_group
@@ -740,10 +740,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %tensor = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #A, #smem, mutable>
     %index = arith.constant 1 : i32

-    // CHECK: llvm.mlir.constant(0 : i32) : i32
-    // CHECK: llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.mul
-    // CHECK: llvm.add
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4;"
     // CHECK: llvm.inline_asm
     // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 81 additions & 67 deletions
@@ -1157,7 +1157,6 @@ struct AsyncCopyGlobalToLocalOpConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
     auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    auto srcLayout = srcTy.getEncoding();

     Value llDst = adaptor.getResult();
     Value llSrc = adaptor.getSrc();
@@ -1167,27 +1166,40 @@ struct AsyncCopyGlobalToLocalOpConversion
     // %src
     auto srcElems = unpackLLElements(loc, llSrc, rewriter);

-    // %dst
-    auto smemObj =
-        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
     // %mask
     SmallVector<Value> maskElems;
     if (llMask) {
       maskElems = unpackLLElements(loc, llMask, rewriter);
       assert(srcElems.size() == maskElems.size());
     }

+    // We assume other = 0, see XXX(Keren) below
     // %other
-    SmallVector<Value> otherElems;
-    if (llOther) {
-      // FIXME(Keren): assume other is 0 for now.
-      //
-      // It's not necessary for now because the pipeline pass will skip
-      // generating insert_slice_async if the load op has any "other" tensor.
-      otherElems = unpackLLElements(loc, llOther, rewriter);
-      assert(srcElems.size() == otherElems.size());
+    // SmallVector<Value> otherElems;
+    // if (llOther) {
+    //   otherElems = unpackLLElements(loc, llOther, rewriter);
+    //   assert(srcElems.size() == otherElems.size());
+    // }
+
+    // zip(src, mask)
+    SmallVector<Value> vals;
+    auto ptrTy = srcElems[0].getType();
+    auto structTy =
+        LLVM::LLVMStructType::getLiteral(ctx, ArrayRef<Type>{ptrTy, i1_ty});
+    for (int i = 0; i < srcElems.size(); i++) {
+      Value packedArr = rewriter.create<LLVM::UndefOp>(loc, structTy);
+      packedArr = b.insert_val(packedArr, srcElems[i], 0);
+      auto maskElem = llMask ? maskElems[i] : b.false_val();
+      packedArr = b.insert_val(packedArr, maskElem, 1);
+      vals.push_back(packedArr);
     }

+    // Remove broadcasted registers
+    auto srcLayout = ttg::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
+    srcLayout = removeBroadcastSrc.apply(srcLayout);
+    vals = removeBroadcastSrc.apply(vals);
+
     // We can load N elements at a time if:
     // 1. Every group of N source pointers are contiguous. For example, if
     //    N=2, then the pointers should be [x, x+1, y, y+1, ...].
@@ -1198,25 +1210,16 @@ struct AsyncCopyGlobalToLocalOpConversion
     if (mask) {
       maxVec = std::min(maxVec, getMaskAlignment(mask));
     }
+    // The maximum vector size is 128 bits on NVIDIA GPUs.
+    maxVec = std::min(maxVec, 128 / resElemTy.getIntOrFloatBitWidth());

-    // Addresses to store into, one per `vecTy`.
-    VectorType vecTy;
-    SmallVector<Value> shmemAddrs;
-    bool ok = emitTransferBetweenRegistersAndShared(
-        srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo,
-        [&](VectorType vecTy_, Value shmemAddr) {
-          vecTy = vecTy_;
-          shmemAddrs.push_back(shmemAddr);
-        });
-    assert(ok);
-
-    int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8;
-    assert(llvm::isPowerOf2_32(vecBytes));
+    int vecBytes = maxVec * resElemTy.getIntOrFloatBitWidth() / 8;
     if (vecBytes < 4) {
       return emitError(loc, "cp.async does not support transfers smaller than "
                             "4 bytes; calculated this as ")
             << vecBytes << " bytes";
     }
+    assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);

     auto freeVarMasks = getFreeVariableMasks(srcTy);
     // NOTE(@peterbell10): We load redundant data on different CTAs, so the data
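A small worked example of the size check above (a standalone sketch, not part of the commit): with f32 elements and maxVec capped at 128 / 32 = 4, each cp.async moves 16 bytes and uses the .cg cache modifier; 8- or 4-byte transfers fall back to .ca, and anything below 4 bytes is rejected, matching the emitCpAsync callback in the next hunk.

#include <cassert>
#include <cstdio>

int main() {
  const int bitwidth = 32;                    // f32 elements
  const int maxVec = 128 / bitwidth;          // 4 elements per instruction
  const int vecBytes = maxVec * bitwidth / 8; // 16 bytes
  assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);
  const char *modifier = (vecBytes == 16) ? "cg" : "ca";
  std::printf("cp.async.%s.shared.global, %d bytes per copy\n", modifier,
              vecBytes);
}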
@@ -1225,52 +1228,63 @@ struct AsyncCopyGlobalToLocalOpConversion
     freeVarMasks[str_attr("block")] = 0;
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
-    uint32_t regMask = freeVarMasks[str_attr("reg")];

-    for (int i = 0; i < shmemAddrs.size(); i++) {
-      // It's possible that vecTy is larger than 128 bits, in which case we have
-      // to use multiple cp.async instructions.
-      int wordBytes = std::min(vecBytes, 16);
-      int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth();
-      int numWordsInVec = std::max(1, vecBytes / wordBytes);
-      for (int j = 0; j < numWordsInVec; j++) {
-        int elemIdx = i * vecTy.getNumElements() + j * wordElems;
-
-        if (!isCanonicalIndex(elemIdx, regMask)) {
-          continue; // Skip redundant registers
-        }
+    auto emitCpAsync = [&b, threadPred, ptrTy, hasMask = bool(llMask)](
+                           ConversionPatternRewriter &rewriter, Location loc,
+                           ArrayRef<Value> vals, Value shmemAddr, int startIdx,
+                           VectorType vecTy) -> SmallVector<Value> {
+      assert(isa<VectorType>(vecTy));
+      auto *ctx = rewriter.getContext();
+      auto elemTy = vecTy.getElementType();
+      auto nBytes = vecTy.getNumElements() * elemTy.getIntOrFloatBitWidth() / 8;
+      assert(nBytes == 16 || nBytes == 8 || nBytes == 4);
+      // Tune CG and CA.
+      CacheModifier srcCacheModifier =
+          nBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
+
+      auto structElem = vals[startIdx];
+      auto srcElem = b.extract_val(ptrTy, structElem, 0);
+      auto maskElem = b.extract_val(i1_ty, structElem, 1);

-      // Tune CG and CA.
-      CacheModifier srcCacheModifier =
-          wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
-      assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4);
-
-      PTXBuilder ptxBuilder;
-      auto &copyAsyncOp =
-          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
-      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r",
-                                                   /*offset=*/j * wordBytes);
-      auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l");
-      auto *copySize = ptxBuilder.newConstantOperand(wordBytes);
-      auto *srcSize = copySize;
-      if (op.getMask()) {
-        // We don't use predicate in this case, setting src-size to 0
-        // if there's any mask. cp.async will automatically fill the
-        // remaining slots with 0 if cp-size > src-size.
-        // XXX(Keren): Always assume other = 0 for now.
-        // When 'other != 0' is supported, we will need to fold the
-        // op.getMask() and redundantDataMask() into the same predicate, the
-        // way it is done for LoadOp.
-        auto selectOp =
-            b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0));
-        srcSize = ptxBuilder.newOperand(selectOp, "r");
-      }
-
-      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-          .maybePredicate(threadPred);
-      ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+      PTXBuilder ptxBuilder;
+      auto &copyAsyncOp =
+          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddr, "r");
+      auto *srcOperand = ptxBuilder.newAddrOperand(srcElem, "l");
+      auto *copySize = ptxBuilder.newConstantOperand(nBytes);
+      auto *srcSize = copySize;
+      if (hasMask) {
+        // We don't use predicate in this case, setting src-size to 0
+        // if there's any mask. cp.async will automatically fill the
+        // remaining slots with 0 if cp-size > src-size.
+        // XXX(Keren): Always assume other = 0 for now.
+        // When 'other != 0' is supported, we will need to fold the
+        // op.getMask() and redundantDataMask() into the same predicate, the
+        // way it is done for LoadOp.
+        auto selectOp = b.select(maskElem, b.i32_val(nBytes), b.i32_val(0));
+        srcSize = ptxBuilder.newOperand(selectOp, "r");
       }
+      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+          .maybePredicate(threadPred);
+      ptxBuilder.launch(rewriter, loc, void_ty(ctx));
+      return {};
+    };
+
+    // %dst
+    auto smemObj =
+        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
+    auto smemLayout =
+        ttg::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+    auto cvt = srcLayout.invertAndCompose(smemLayout);
+    if (!cvt.isTrivialOver({str_attr("block")})) {
+      return emitError(loc,
+                       "cp.async does not support non-trivial block dimension");
     }
+    cvt = cvt.sublayout(
+        {str_attr("register"), str_attr("lane"), str_attr("warp")},
+        {str_attr("offset")});
+    lowerLdSt(loc, ctx, cvt, vals, resElemTy, smemObj.getBase(), rewriter,
+              targetInfo, maxVec, emitCpAsync);

     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(