Commit 8791ac1

Revert "[BACKEND] Move cp.async to better lowering sequence (#7304)" (#7309)
This reverts commit b188033, as it causes a regression for now.
1 parent 48229b7 commit 8791ac1
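In effect, the revert removes the generic callback-based lowerLdSt path that #7304 introduced (see Utility.h and Utility.cpp below) and restores the previous direct cp.async emission in AsyncCopyGlobalToLocalOpConversion, which enumerates shared-memory addresses via emitTransferBetweenRegistersAndShared and splits transfers wider than 128 bits across multiple cp.async instructions.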

File tree (5 files changed: +88 -149 lines changed)

include/triton/Conversion/TritonGPUToLLVM/Utility.h
include/triton/Tools/LinearLayout.h
lib/Conversion/TritonGPUToLLVM/Utility.cpp
test/Conversion/tritongpu_to_llvm.mlir
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 12 deletions
@@ -566,18 +566,6 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 ConversionPatternRewriter &rewriter,
                 const TargetInfoBase &targetInfo);
 
-// Lower an ld/st-like operation given a layout and a callback that creates
-// the PTX instruction. Lowers to st when valsArray is non-empty, and to ld
-// when it is empty, in which case it returns the loaded values.
-SmallVector<Value> lowerLdSt(
-    Location loc, MLIRContext *ctx, LinearLayout cvt,
-    ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase, ConversionPatternRewriter &rewriter,
-    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
-    std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
-                                     ArrayRef<Value>, Value, int, VectorType)>
-        lowerInst);
-
 // Lower local_load/local_store via ld.shared/st.shared
 SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
                                   // Map from registers to offset
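For context on the removed declaration: lowerLdSt generalized the shared-memory ld/st lowering by taking an emit-one-instruction callback, with lowerLdStShared reduced to a thin wrapper over it (see its deleted body in Utility.cpp below). A self-contained analogue of that control flow, with the MLIR types replaced by plain ones (illustrative only, not the Triton API):

    #include <cstdio>
    #include <functional>
    #include <vector>

    // Analogue of the removed lowerLdSt: "store" when vals is non-empty,
    // "load" otherwise; lowerInst emits one vectorized chunk and returns
    // any loaded values.
    std::vector<int>
    lowerLdSt(const std::vector<int> &vals, int numChunks,
              const std::function<std::vector<int>(const std::vector<int> &,
                                                   int)> &lowerInst) {
      std::vector<int> out;
      for (int idx = 0; idx < numChunks; ++idx) {
        auto loaded = lowerInst(vals, idx);
        out.insert(out.end(), loaded.begin(), loaded.end());
      }
      return out; // empty for stores, the loaded values for loads
    }

    int main() {
      // "Load" path: no input values; the callback produces one value per chunk.
      auto out = lowerLdSt({}, 3, [](const std::vector<int> &, int idx) {
        return std::vector<int>{idx * 10};
      });
      printf("loaded %zu values\n", out.size()); // loaded 3 values
    }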

include/triton/Tools/LinearLayout.h

Lines changed: 0 additions & 5 deletions
@@ -838,11 +838,6 @@ class ColumnAction {
   // Inverse of the action
   ColumnAction inverse() const;
 
-  static ColumnAction identity(StringAttr inDim, size_t inSizeLog2) {
-    return ColumnAction(llvm::to_vector(llvm::seq<size_t>(inSizeLog2)), inDim,
-                        inSizeLog2);
-  }
-
   // Returns true if the action is the identity
   bool isIdentity() const { return m_isIdentity; }
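The deleted helper built the identity permutation [0, 1, ..., inSizeLog2-1] over a column's basis indices; llvm::to_vector(llvm::seq<size_t>(n)) materializes exactly that sequence. A plain-C++ sketch of the vector it constructed (no LLVM dependency; illustrative):

    #include <cstdio>
    #include <vector>

    int main() {
      size_t inSizeLog2 = 3; // log2 of the column's size
      std::vector<size_t> perm;
      for (size_t i = 0; i < inSizeLog2; ++i)
        perm.push_back(i); // what llvm::seq<size_t>(inSizeLog2) yields
      for (size_t v : perm)
        printf("%zu ", v); // prints: 0 1 2
      printf("\n");
    }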

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 16 additions & 50 deletions
@@ -12,8 +12,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/MathExtras.h"
 
-#include <functional>
-
 #if defined(_MSC_VER) && !defined(__clang__)
 // from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0
 #include <intrin.h>
@@ -515,28 +513,20 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
 }
 
 std::pair<int, ColumnAction>
-largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
-                     std::optional<int> maybeMaxVecElems = std::nullopt) {
+largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth) {
   // Find the largest vectorisation we can use:
   StringAttr kReg = str_attr("register");
   StringAttr kOffset = str_attr("offset");
   LinearLayout quot;
   LinearLayout tile;
   ColumnAction permutation;
-  // If there are restrictions on the vectorisation, we don't allow
-  // permutations.
-  auto allowPerm = !maybeMaxVecElems.has_value();
-  auto maxVecElems = maybeMaxVecElems.value_or(128 / bitwidth);
-  for (int v = maxVecElems; v >= 1; v /= 2) {
+  for (int v = 128 / bitwidth; v >= 1; v /= 2) {
     tile = LinearLayout::identity1D(v, kReg, kOffset);
     auto maybePerm = regPermForDivide(cvt, tile, /*left=*/true);
     if (!maybePerm) {
       continue;
     }
     permutation = *maybePerm;
-    if (!allowPerm && !permutation.isIdentity()) {
-      continue;
-    }
     auto newCvt = permutation.apply(cvt);
     auto maybeQuot = divideLeft(newCvt, tile);
     if (!maybeQuot) {
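With the cap and the permutation restriction gone, the restored search simply starts at the 128-bit hardware maximum (128/bitwidth elements) and halves until a tile divides the layout. The loop structure in isolation, with divides() standing in for the regPermForDivide/divideLeft checks (illustrative):

    #include <cstdio>

    // Stand-in for the layout-divisibility checks done via regPermForDivide
    // and divideLeft in the real code.
    static bool divides(int v) { return v <= 4; } // pretend only tiles <= 4 fit

    static int largestVectorisation(int bitwidth) {
      for (int v = 128 / bitwidth; v >= 1; v /= 2) // 128-bit max vector
        if (divides(v))
          return v;
      return 1;
    }

    int main() {
      printf("%d\n", largestVectorisation(16)); // tries 8, then 4 -> prints 4
    }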
@@ -554,39 +544,6 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 Type llvmElemTy, Value smemBase,
                 ConversionPatternRewriter &rewriter,
                 const TargetInfoBase &targetInfo) {
-
-  bool isStore = !valsArray.empty();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-  auto emitCpAsync = [&](ConversionPatternRewriter &rewriter, Location loc,
-                         ArrayRef<Value> vals, Value shmemAddr, int idx,
-                         VectorType vecTy) -> SmallVector<Value> {
-    auto length = vecTy.getNumElements();
-    if (isStore) {
-      Value valsVec =
-          packLLVector(loc, ArrayRef<Value>(vals).slice(idx, length), rewriter);
-      targetInfo.storeDShared(rewriter, loc, shmemAddr, std::nullopt, valsVec,
-                              /*pred=*/b.true_val());
-      return {};
-    } else {
-      assert(vals.empty());
-      Value valsVec = targetInfo.loadDShared(
-          rewriter, loc, shmemAddr, std::nullopt, vecTy, /*pred=*/b.true_val());
-      return unpackLLVector(loc, valsVec, rewriter);
-    }
-  };
-  return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase, rewriter,
-                   targetInfo, {}, emitCpAsync);
-}
-
-SmallVector<Value> lowerLdSt(
-    Location loc, MLIRContext *ctx, LinearLayout cvt,
-    ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase, ConversionPatternRewriter &rewriter,
-    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
-    std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
-                                     ArrayRef<Value>, Value, int, VectorType)>
-        lowerInst) {
   auto vals = to_vector(valsArray);
   bool isStore = !vals.empty();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -597,8 +554,7 @@ SmallVector<Value> lowerLdSt(
   auto kOffset = str_attr("offset");
   auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
 
-  auto [elemsPerVec, permutation] =
-      largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems);
+  auto [elemsPerVec, permutation] = largestVectorisation(ctx, cvt, bitwidth);
 
   cvt = permutation.apply(cvt);
   if (isStore) {
@@ -630,7 +586,6 @@
           {{kReg, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}})[0]
           .second;
   SmallVector<Value> outVals;
-  auto vecTy = vec_ty(llvmElemTy, elemsPerVec);
   for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) {
     auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second;
     auto regIdxI8 = regIdx * (bitwidth / 8);
@@ -643,8 +598,19 @@
       Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8));
       auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset,
                            LLVM::GEPNoWrapFlags::inbounds);
-      llvm::append_range(outVals,
-                         lowerInst(rewriter, loc, vals, vecAddr, i + j, vecTy));
+      // Lezcano: Do we want to use getFreeVariableMasks for pred or nah?
+      if (isStore) {
+        Value valsVec = packLLVector(
+            loc, ArrayRef<Value>(vals).slice(i + j, elemsPerVec), rewriter);
+        targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
+                                /*pred=*/b.true_val());
+      } else {
+        Value valsVec =
+            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
+                                   vec_ty(llvmElemTy, elemsPerVec),
+                                   /*pred=*/b.true_val());
+        llvm::append_range(outVals, unpackLLVector(loc, valsVec, rewriter));
+      }
     }
   }

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 5 additions & 1 deletion
@@ -656,7 +656,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   %index = arith.constant 1 : i32
 
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
   // CHECK: nvvm.cp.async.commit.group
   %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr<f32>, #AL> -> !ttg.memdesc<16x64xf32, #A, #smem, mutable>
   ttg.async_commit_group
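This CHECK update is the user-visible effect of the revert in this test: the second cp.async once again encodes the 16-byte shared-memory offset as an immediate in the asm string ("+ 16"), whereas the reverted lowering from #7304 folded it into the address operand and emitted "+ 0".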
@@ -740,6 +740,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   %tensor = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #A, #smem, mutable>
   %index = arith.constant 1 : i32
 
+  // CHECK: llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.mlir.constant(16 : i32) : i32
+  // CHECK: llvm.mul
+  // CHECK: llvm.add
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4;"
   // CHECK: llvm.inline_asm
   // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
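The re-added CHECK lines match the reverted lowering's explicit shared-memory address arithmetic: i32 constants combined with an llvm.mul/llvm.add sequence to compute the destination offset before the cp.async, a sequence the #7304 lowering did not emit.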

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 67 additions & 81 deletions
@@ -1157,6 +1157,7 @@ struct AsyncCopyGlobalToLocalOpConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
     auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+    auto srcLayout = srcTy.getEncoding();
 
     Value llDst = adaptor.getResult();
     Value llSrc = adaptor.getSrc();
@@ -1166,40 +1167,27 @@
     // %src
     auto srcElems = unpackLLElements(loc, llSrc, rewriter);
 
+    // %dst
+    auto smemObj =
+        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
     // %mask
     SmallVector<Value> maskElems;
     if (llMask) {
       maskElems = unpackLLElements(loc, llMask, rewriter);
       assert(srcElems.size() == maskElems.size());
     }
 
-    // We assume other = 0, see XXX(Keren) below
     // %other
-    // SmallVector<Value> otherElems;
-    // if (llOther) {
-    //   otherElems = unpackLLElements(loc, llOther, rewriter);
-    //   assert(srcElems.size() == otherElems.size());
-    // }
-
-    // zip(src, mask)
-    SmallVector<Value> vals;
-    auto ptrTy = srcElems[0].getType();
-    auto structTy =
-        LLVM::LLVMStructType::getLiteral(ctx, ArrayRef<Type>{ptrTy, i1_ty});
-    for (int i = 0; i < srcElems.size(); i++) {
-      Value packedArr = rewriter.create<LLVM::UndefOp>(loc, structTy);
-      packedArr = b.insert_val(packedArr, srcElems[i], 0);
-      auto maskElem = llMask ? maskElems[i] : b.false_val();
-      packedArr = b.insert_val(packedArr, maskElem, 1);
-      vals.push_back(packedArr);
+    SmallVector<Value> otherElems;
+    if (llOther) {
+      // FIXME(Keren): assume other is 0 for now.
+      //
+      // It's not necessary for now because the pipeline pass will skip
+      // generating insert_slice_async if the load op has any "other" tensor.
+      otherElems = unpackLLElements(loc, llOther, rewriter);
+      assert(srcElems.size() == otherElems.size());
     }
 
-    // Remove broadcasted registers
-    auto srcLayout = ttg::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
-    srcLayout = removeBroadcastSrc.apply(srcLayout);
-    vals = removeBroadcastSrc.apply(vals);
-
     // We can load N elements at a time if:
     // 1. Every group of N source pointers are contiguous. For example, if
     //    N=2, then the pointers should be [x, x+1, y, y+1, ...].
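Condition 1 (grouped pointer contiguity) is mechanical to check. A standalone sketch with plain integers standing in for pointers (illustrative):

    #include <cstdio>
    #include <vector>

    // True if every group of N consecutive "pointers" is contiguous,
    // e.g. [x, x+1, y, y+1, ...] for N = 2.
    static bool groupsContiguous(const std::vector<int> &ptrs, int n) {
      for (size_t i = 0; i + n <= ptrs.size(); i += n)
        for (int j = 1; j < n; ++j)
          if (ptrs[i + j] != ptrs[i] + j)
            return false;
      return true;
    }

    int main() {
      printf("%d\n", groupsContiguous({10, 11, 40, 41}, 2)); // 1
      printf("%d\n", groupsContiguous({10, 12, 40, 41}, 2)); // 0
    }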
@@ -1210,16 +1198,25 @@
     if (mask) {
      maxVec = std::min(maxVec, getMaskAlignment(mask));
     }
-    // The maximum vector size is 128 bits on NVIDIA GPUs.
-    maxVec = std::min(maxVec, 128 / resElemTy.getIntOrFloatBitWidth());
 
-    int vecBytes = maxVec * resElemTy.getIntOrFloatBitWidth() / 8;
+    // Addresses to store into, one per `vecTy`.
+    VectorType vecTy;
+    SmallVector<Value> shmemAddrs;
+    bool ok = emitTransferBetweenRegistersAndShared(
+        srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo,
+        [&](VectorType vecTy_, Value shmemAddr) {
+          vecTy = vecTy_;
+          shmemAddrs.push_back(shmemAddr);
+        });
+    assert(ok);
+
+    int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8;
+    assert(llvm::isPowerOf2_32(vecBytes));
     if (vecBytes < 4) {
       return emitError(loc, "cp.async does not support transfers smaller than "
                             "4 bytes; calculated this as ")
              << vecBytes << " bytes";
     }
-    assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);
 
     auto freeVarMasks = getFreeVariableMasks(srcTy);
     // NOTE(@peterbell10): We load redundant data on different CTAs, so the data
@@ -1228,63 +1225,52 @@
     freeVarMasks[str_attr("block")] = 0;
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
+    uint32_t regMask = freeVarMasks[str_attr("reg")];
 
-    auto emitCpAsync = [&b, threadPred, ptrTy, hasMask = bool(llMask)](
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<Value> vals, Value shmemAddr, int startIdx,
-                           VectorType vecTy) -> SmallVector<Value> {
-      assert(isa<VectorType>(vecTy));
-      auto *ctx = rewriter.getContext();
-      auto elemTy = vecTy.getElementType();
-      auto nBytes = vecTy.getNumElements() * elemTy.getIntOrFloatBitWidth() / 8;
-      assert(nBytes == 16 || nBytes == 8 || nBytes == 4);
-      // Tune CG and CA.
-      CacheModifier srcCacheModifier =
-          nBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
-
-      auto structElem = vals[startIdx];
-      auto srcElem = b.extract_val(ptrTy, structElem, 0);
-      auto maskElem = b.extract_val(i1_ty, structElem, 1);
+    for (int i = 0; i < shmemAddrs.size(); i++) {
+      // It's possible that vecTy is larger than 128 bits, in which case we have
+      // to use multiple cp.async instructions.
+      int wordBytes = std::min(vecBytes, 16);
+      int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth();
+      int numWordsInVec = std::max(1, vecBytes / wordBytes);
+      for (int j = 0; j < numWordsInVec; j++) {
+        int elemIdx = i * vecTy.getNumElements() + j * wordElems;
+
+        if (!isCanonicalIndex(elemIdx, regMask)) {
+          continue; // Skip redundant registers
+        }
 
-      PTXBuilder ptxBuilder;
-      auto &copyAsyncOp =
-          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
-      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddr, "r");
-      auto *srcOperand = ptxBuilder.newAddrOperand(srcElem, "l");
-      auto *copySize = ptxBuilder.newConstantOperand(nBytes);
-      auto *srcSize = copySize;
-      if (hasMask) {
-        // We don't use predicate in this case, setting src-size to 0
-        // if there's any mask. cp.async will automatically fill the
-        // remaining slots with 0 if cp-size > src-size.
-        // XXX(Keren): Always assume other = 0 for now.
-        // When 'other != 0' is supported, we will need to fold the
-        // op.getMask() and redundantDataMask() into the same predicate, the
-        // way it is done for LoadOp.
-        auto selectOp = b.select(maskElem, b.i32_val(nBytes), b.i32_val(0));
-        srcSize = ptxBuilder.newOperand(selectOp, "r");
-      }
-      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-          .maybePredicate(threadPred);
-      ptxBuilder.launch(rewriter, loc, void_ty(ctx));
-      return {};
-    };
+        // Tune CG and CA.
+        CacheModifier srcCacheModifier =
+            wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
+        assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4);
+
+        PTXBuilder ptxBuilder;
+        auto &copyAsyncOp =
+            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+        auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r",
+                                                     /*offset=*/j * wordBytes);
+        auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l");
+        auto *copySize = ptxBuilder.newConstantOperand(wordBytes);
+        auto *srcSize = copySize;
+        if (op.getMask()) {
+          // We don't use predicate in this case, setting src-size to 0
+          // if there's any mask. cp.async will automatically fill the
+          // remaining slots with 0 if cp-size > src-size.
+          // XXX(Keren): Always assume other = 0 for now.
+          // When 'other != 0' is supported, we will need to fold the
+          // op.getMask() and redundantDataMask() into the same predicate, the
+          // way it is done for LoadOp.
+          auto selectOp =
+              b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0));
+          srcSize = ptxBuilder.newOperand(selectOp, "r");
+        }
 
-    // %dst
-    auto smemObj =
-        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
-    auto smemLayout =
-        ttg::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-    auto cvt = srcLayout.invertAndCompose(smemLayout);
-    if (!cvt.isTrivialOver({str_attr("block")})) {
-      return emitError(loc,
-                       "cp.async does not support non-trivial block dimension");
+        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+            .maybePredicate(threadPred);
+        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+      }
     }
-    cvt = cvt.sublayout(
-        {str_attr("register"), str_attr("lane"), str_attr("warp")},
-        {str_attr("offset")});
-    lowerLdSt(loc, ctx, cvt, vals, resElemTy, smemObj.getBase(), rewriter,
-              targetInfo, maxVec, emitCpAsync);
 
     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(
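Two details of the restored loop are worth spelling out. First, a vector wider than 128 bits is split into cp.async "words" of at most 16 bytes, using the CG cache modifier for 16-byte words and CA otherwise. Second, masking avoids a per-lane predicate: src-size is selected as wordBytes or 0, and cp.async zero-fills the destination when cp-size > src-size. The index/byte arithmetic in isolation, with values mimicking a 32-byte vector of f32 (illustrative):

    #include <algorithm>
    #include <cstdio>

    int main() {
      int elemBits = 32;                       // f32 elements
      int numElems = 8;                        // vecTy ~ vector<8xf32>
      int vecBytes = numElems * elemBits / 8;  // 32 bytes per shared-mem vector
      int wordBytes = std::min(vecBytes, 16);  // cp.async moves <= 16 bytes
      int wordElems = wordBytes * 8 / elemBits;              // 4 elems per word
      int numWordsInVec = std::max(1, vecBytes / wordBytes); // 2 words here
      for (int j = 0; j < numWordsInVec; j++)
        printf("word %d: elem index +%d, smem offset +%d bytes\n", j,
               j * wordElems, j * wordBytes);
      // Masked case: src-size = mask ? wordBytes : 0, and cp.async zero-fills
      // the remainder whenever cp-size > src-size.
    }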
