Commit 4bedc47

Revert "Revert "[AMD] Enable lowerLocalLdSt for AMD path (#7355)"" (#4848)
Signed-off-by: Anatoly Myachev <[email protected]>
2 parents 83341cf + 68a3466 commit 4bedc47

36 files changed: +663 -1471 lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ class TargetInfoBase {
                             std::optional<Value> ctaId, Value val,
                             Value pred) const = 0;
   virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
-                            std::optional<Value> ctaId, Type elemTy,
-                            Value pred) const = 0;
+                            std::optional<Value> ctaId, Type elemTy, Value pred,
+                            Operation *localLoadOp = nullptr) const = 0;

   void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val,
                    Value pred) const {
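
The only signature change here is the trailing Operation *localLoadOp, which lets a backend see the originating triton::gpu::LocalLoadOp when it lowers a shared-memory load. A minimal call-site sketch, assuming rewriter, loc, shmemAddr, llvmElemTy, pred, and the LocalLoadOp op are in scope (the names are illustrative, not taken from this commit):

  Value loaded = targetInfo.loadDShared(rewriter, loc, shmemAddr,
                                        /*ctaId=*/std::nullopt, llvmElemTy,
                                        pred, /*localLoadOp=*/op);

Existing callers that do not pass the extra argument keep compiling because the parameter defaults to nullptr.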

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 22 additions & 24 deletions
@@ -515,10 +515,13 @@ SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);

-// Emits the required padding in elements for the given shared memory offset
+// Emits the required padding given shared memory offset
+// - If `offsetInBytes` is true, smemOffset and padding is assumed in bytes.
+// - If false, smemOffset and padding are assumed to be scaled by element
+//   bitwidth, in which case, `bitwidth` is not used.
 Value emitPadding(Location loc, RewriterBase &rewriter,
                   triton::gpu::PaddedSharedEncodingAttr layout,
-                  Value smemOffset);
+                  unsigned bitwidth, Value smemOffset, bool offsetInBytes);

 // Emits IR to load data from shared memory into registers, or to store data
 // from registers into shared memory.
@@ -546,39 +549,33 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
                Value laneId, Value warpId,
                std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

-SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
-                                           Type elemLlvmTy,
-                                           const SharedMemoryObject &smemObj,
-                                           Location loc, RewriterBase &rewriter,
-                                           const TargetInfoBase &target);
-
-void storeDistributedToShared(triton::gpu::MemDescType dstTy,
-                              RankedTensorType srcTy, Type elemLlvmTy,
-                              ArrayRef<Value> srcVals,
-                              const SharedMemoryObject &smemObj, Location loc,
-                              RewriterBase &rewriter,
-                              const TargetInfoBase &target);
-
 // Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
 // We might want to merge them at some point, but having to support
 // ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
 // Lowers to st when valArrays is empty, and to ld when it is not,
 // and returns the output values.
-SmallVector<Value>
-lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
-                ArrayRef<Value> valsArray, // Input for store, output for load
-                Type llvmElemTy, Value smemBase, Value affineOffset,
-                uint64_t maskSpanAffineOffset,
-                ConversionPatternRewriter &rewriter,
-                const TargetInfoBase &targetInfo);
+// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
+// and computes a new offset (mlir::Value) by applying padding based on
+// shared memory layout.
+SmallVector<Value> lowerLdStShared(
+    Location loc, MLIRContext *ctx, LinearLayout cvt,
+    ArrayRef<Value> valsArray, // Input for store, output for load
+    Type llvmElemTy, Value smemBase,
+    std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
+    uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
+    const TargetInfoBase &targetInfo, Operation *localLoadOp = nullptr);

 // Lower an ld/st-like operation given a layout and a callback that creates the
 // PTX instruction Lowers to st when valArrays is empty, and to ld when it is
 // not, and returns the output values.
+// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
+// and computes a new offset (mlir::Value) by applying padding based on
+// shared memory layout.
 SmallVector<Value> lowerLdSt(
     Location loc, MLIRContext *ctx, LinearLayout cvt,
     ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase, Value affineOffset,
+    Type llvmElemTy, Value smemBase,
+    std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
     uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
     const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
@@ -592,7 +589,8 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
                ArrayRef<Value> valsArray, // Input for store, empty for load
                Type llvmElemTy, triton::gpu::MemDescType srcTy,
                SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
-               const TargetInfoBase &targetInfo);
+               const TargetInfoBase &targetInfo,
+               Operation *localLoadOp = nullptr);

 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter);
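
To make the new calcPaddedOffset parameter concrete: callers lowering non-padded layouts pass an identity lambda (as ConvertLayoutOpToLLVM.cpp does below), while a padded-shared path can wrap the updated emitPadding. A hedged sketch, assuming loc, rewriter, a PaddedSharedEncodingAttr paddedLayout, and llvmElemTy are in scope and that offsets are element-scaled rather than byte-scaled; this is illustrative, not necessarily the exact wiring inside lowerLocalLdSt:

  // No padding: the callback returns the offset unchanged.
  auto noPaddingOffset = [](Value v) { return v; };

  // Padded shared memory: query the padding owed at this offset and add it.
  auto calcPaddedOffset = [&](Value smemOffset) -> Value {
    unsigned bitwidth = llvmElemTy.getIntOrFloatBitWidth();
    Value padding = emitPadding(loc, rewriter, paddedLayout, bitwidth,
                                smemOffset, /*offsetInBytes=*/false);
    return rewriter.create<LLVM::AddOp>(loc, smemOffset, padding);
  };

Per the comments above, lowerLdStShared and lowerLdSt apply this callback to the base offset they compute before forming the final shared-memory address.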

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   let genVerifyDecl = 1;
 }

-def PaddeddSharedEncodingAttr
+def PaddedSharedEncodingAttr
     : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
                      [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "padded_shared";

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 5 additions & 3 deletions
@@ -201,6 +201,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

     assert(permutedInVals.size() == tileSize * nReps);
     SmallVector<Value> outVals;
+    auto noPaddingOffset = [](Value v) { return v; };
     auto affineOffset = b.i32_val(0);
     auto maskSpanAffineOffset = 0;
     for (int i = 0; i < nReps; ++i) {
@@ -211,12 +212,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
           ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
       // Store
       lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
-                      affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
+                      noPaddingOffset, affineOffset, maskSpanAffineOffset,
+                      rewriter, targetInfo);
       b.barrier();
       // Load
       SmallVector<Value> tileOutVals = lowerLdStShared(
-          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, affineOffset,
-          maskSpanAffineOffset, rewriter, targetInfo);
+          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
+          affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
       llvm::append_range(outVals, tileOutVals);
     }

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 42 additions & 73 deletions
@@ -12,24 +12,6 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;

-// blocked -> shared.
-// Swizzling in shared memory to avoid bank conflict. Normally used for
-// A/B operands of dots.
-void lowerDistributedToShared(Location loc, Value src, Value dst,
-                              Value adaptorSrc,
-                              const SharedMemoryObject &smemObj,
-                              const LLVMTypeConverter *typeConverter,
-                              ConversionPatternRewriter &rewriter,
-                              const TargetInfoBase &targetInfo) {
-  auto srcTy = cast<RankedTensorType>(src.getType());
-  auto dstTy = cast<MemDescType>(dst.getType());
-  auto elemTy = typeConverter->convertType(srcTy.getElementType());
-
-  auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
-  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
-                           targetInfo);
-}
-
 LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                               MemDescType memDescTy, SharedMemoryObject smemObj,
                               ArrayRef<Value> inVals,
@@ -39,19 +21,25 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto regTy = cast<RankedTensorType>(regVal.getType());
   auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());

-  auto regLayout = toLinearLayout(regTy);
-  auto sharedLayout = toLinearLayout(memDescTy);
-  auto cvt = regLayout.invertAndCompose(sharedLayout);
-
-  auto kBlock = str_attr("block");
-  // NYI. We would need to emit a map.shared::cluster instruction.
-  if (!cvt.isTrivialOver({kBlock})) {
-    return failure();
-  }
   auto kReg = str_attr("register");
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
+  auto regLayout = toLinearLayout(regTy);
+  auto paddedLayout =
+      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding());
+  LinearLayout cvt = LinearLayout::empty();
+  if (paddedLayout) {
+    cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
+  } else {
+    auto sharedLayout = toLinearLayout(memDescTy);
+    cvt = regLayout.invertAndCompose(sharedLayout);
+    auto kBlock = str_attr("block");
+    // NYI. We would need to emit a map.shared::cluster instruction.
+    if (!cvt.isTrivialOver({kBlock})) {
+      return failure();
+    }
+  }
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
   lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                  rewriter, targetInfo);
@@ -115,25 +103,12 @@ struct LocalAllocOpConversion
                                                      loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
-      // [Legacy local_load/local_store]
-      // TODO(Lezcano) We should activate this path for other targets as it's
-      // more efficient. AFAIK The main blockers are:
-      // - The legacy path calls localLoadOpAnnotation
-      // - The legacy path calls llvm.load/llvm.store unconditionally, while
-      //   the AMD lowering of storeDShared does not, even when the predicate
-      //   is constant true.
-      if (targetInfo.isCuda()) {
-        auto *ctx = op.getContext();
-        auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-        if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
-                                   inVals, typeConverter, rewriter,
-                                   targetInfo))) {
-          return failure();
-        }
-      } else {
-        lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
-                                 adaptor.getSrc(), smemObj, typeConverter,
-                                 rewriter, targetInfo);
+      auto *ctx = op.getContext();
+      auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+      if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
+                                 inVals, typeConverter, rewriter,
+                                 targetInfo))) {
+        return failure();
       }
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
@@ -181,32 +156,31 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                          llvmElemTy, rewriter);

-    // See [Legacy local_load/local_store]
-    if (!targetInfo.isCuda()) {
-      SmallVector<Value> outVals = loadSharedToDistributed(
-          op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
-      Value result =
-          packLLElements(loc, typeConverter, outVals, rewriter, regTy);
-      rewriter.replaceOp(op, result);
-      return success();
-    }
-
-    auto regLayout = toLinearLayout(regTy);
-    auto sharedLayout = toLinearLayout(memDescTy);
-    auto cvt = regLayout.invertAndCompose(sharedLayout);
-    auto kBlock = str_attr("block");
-    // NYI. We would need to emit a map.shared::cluster instruction.
-    if (!cvt.isTrivialOver({kBlock})) {
-      return failure();
-    }
+    auto sharedEnc =
+        cast<triton::gpu::SharedEncodingTrait>(memDescTy.getEncoding());
     auto kReg = str_attr("register");
     auto kLane = str_attr("lane");
     auto kWarp = str_attr("warp");
     auto kOffset = str_attr("offset");
+    auto regLayout = toLinearLayout(regTy);
+    auto paddedLayout =
+        dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc);
+    LinearLayout cvt = LinearLayout::empty();
+    if (paddedLayout) {
+      cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
+    } else {
+      auto sharedLayout = toLinearLayout(memDescTy);
+      cvt = regLayout.invertAndCompose(sharedLayout);
+      auto kBlock = str_attr("block");
+      // NYI. We would need to emit a map.shared::cluster instruction.
+      if (!cvt.isTrivialOver({kBlock})) {
+        return failure();
+      }
+    }
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});

     auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
-                                  smemObj, rewriter, targetInfo);
+                                  smemObj, rewriter, targetInfo, op);

     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);
@@ -243,14 +217,9 @@ struct LocalStoreOpConversion
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
                                                          llvmElemTy, rewriter);
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    if (targetInfo.isCuda()) {
-      if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
-                                 typeConverter, rewriter, targetInfo))) {
-        return failure();
-      }
-    } else {
-      lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
-                               smemObj, typeConverter, rewriter, targetInfo);
+    if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
+                               typeConverter, rewriter, targetInfo))) {
+      return failure();
     }

     rewriter.eraseOp(op);
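
Both lowerLocalStore and LocalLoadOpConversion now build the same register-to-offset conversion; only the forwarding of the originating op differs. A consolidated sketch of that shared logic, with buildRegToSharedCvt as a hypothetical helper that is not part of this commit (str_attr is assumed to be the usual macro resolving against the local ctx):

  static FailureOr<LinearLayout>
  buildRegToSharedCvt(RankedTensorType regTy,
                      triton::gpu::MemDescType memDescTy) {
    MLIRContext *ctx = regTy.getContext(); // used by the str_attr macro
    auto kReg = str_attr("register");
    auto kLane = str_attr("lane");
    auto kWarp = str_attr("warp");
    auto kOffset = str_attr("offset");
    auto regLayout = toLinearLayout(regTy);
    LinearLayout cvt = LinearLayout::empty();
    if (isa<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding())) {
      // Padded path: no swizzling to invert; keep the flat register->offset
      // mapping and let calcPaddedOffset insert the padding later.
      cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
    } else {
      // Swizzled path: compose the register layout with the shared layout.
      cvt = regLayout.invertAndCompose(toLinearLayout(memDescTy));
      auto kBlock = str_attr("block");
      // NYI: a non-trivial block dimension would need map.shared::cluster.
      if (!cvt.isTrivialOver({kBlock}))
        return failure();
    }
    return cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
  }

The behavioral change for AMD is that the targetInfo.isCuda() guards are gone, so both local_load and local_store now go through lowerLocalLdSt instead of the removed lowerDistributedToShared / loadSharedToDistributed helpers.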
