Commit 65d0eb6

Revert "[AMD] Enable lowerLocalLdSt for AMD path (#7355)"
This reverts commit 8a5862d.

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b4fe355 · commit 65d0eb6

32 files changed (+1453 −643 lines)

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Lines changed: 2 additions & 2 deletions

@@ -25,8 +25,8 @@ class TargetInfoBase {
                             std::optional<Value> ctaId, Value val,
                             Value pred) const = 0;
   virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
-                            std::optional<Value> ctaId, Type elemTy, Value pred,
-                            Operation *localLoadOp = nullptr) const = 0;
+                            std::optional<Value> ctaId, Type elemTy,
+                            Value pred) const = 0;

   void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val,
                    Value pred) const {
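
For reference, a minimal call-site sketch of these hooks with the reverted signatures. It is not part of the diff: rewriter, loc, ptr, elemTy, and pred are assumed to be in scope at some lowering site, and the declaration whose tail appears just above loadDShared is assumed to be storeDShared.

  // Illustrative only: after the revert, loadDShared no longer takes a
  // trailing localLoadOp pointer; the predicate is the last parameter again.
  Value loaded = targetInfo.loadDShared(rewriter, loc, ptr,
                                        /*ctaId=*/std::nullopt, elemTy, pred);
  // storeDShared (assumed name) keeps its (ptr, ctaId, val, pred) shape.
  targetInfo.storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, loaded,
                          pred);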

include/triton/Conversion/TritonGPUToLLVM/Utility.h
Lines changed: 24 additions & 22 deletions

@@ -515,13 +515,10 @@ SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);

-// Emits the required padding given shared memory offset
-// - If `offsetInBytes` is true, smemOffset and padding is assumed in bytes.
-// - If false, smemOffset and padding are assumed to be scaled by element
-//   bitwidth, in which case, `bitwidth` is not used.
+// Emits the required padding in elements for the given shared memory offset
 Value emitPadding(Location loc, RewriterBase &rewriter,
                   triton::gpu::PaddedSharedEncodingAttr layout,
-                  unsigned bitwidth, Value smemOffset, bool offsetInBytes);
+                  Value smemOffset);

 // Emits IR to load data from shared memory into registers, or to store data
 // from registers into shared memory.

@@ -549,33 +546,39 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
                Value laneId, Value warpId,
                std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

+SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
+                                           Type elemLlvmTy,
+                                           const SharedMemoryObject &smemObj,
+                                           Location loc, RewriterBase &rewriter,
+                                           const TargetInfoBase &target);
+
+void storeDistributedToShared(triton::gpu::MemDescType dstTy,
+                              RankedTensorType srcTy, Type elemLlvmTy,
+                              ArrayRef<Value> srcVals,
+                              const SharedMemoryObject &smemObj, Location loc,
+                              RewriterBase &rewriter,
+                              const TargetInfoBase &target);
+
 // Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
 // We might want to merge them at some point, but having to support
 // ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
 // Lowers to st when valArrays is empty, and to ld when it is not,
 // and returns the output values.
-// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
-// and computes a new offset (mlir::Value) by applying padding based on
-// shared memory layout.
-SmallVector<Value> lowerLdStShared(
-    Location loc, MLIRContext *ctx, LinearLayout cvt,
-    ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase,
-    std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
-    uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
-    const TargetInfoBase &targetInfo, Operation *localLoadOp = nullptr);
+SmallVector<Value>
+lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
+                ArrayRef<Value> valsArray, // Input for store, output for load
+                Type llvmElemTy, Value smemBase, Value affineOffset,
+                uint64_t maskSpanAffineOffset,
+                ConversionPatternRewriter &rewriter,
+                const TargetInfoBase &targetInfo);

 // Lower an ld/st-like operation given a layout and a callback that creates the
 // PTX instruction Lowers to st when valArrays is empty, and to ld when it is
 // not, and returns the output values.
-// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
-// and computes a new offset (mlir::Value) by applying padding based on
-// shared memory layout.
 SmallVector<Value> lowerLdSt(
     Location loc, MLIRContext *ctx, LinearLayout cvt,
     ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase,
-    std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
+    Type llvmElemTy, Value smemBase, Value affineOffset,
     uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
     const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,

@@ -589,8 +592,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
                ArrayRef<Value> valsArray, // Input for store, empty for load
                Type llvmElemTy, triton::gpu::MemDescType srcTy,
                SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
-               const TargetInfoBase &targetInfo,
-               Operation *localLoadOp = nullptr);
+               const TargetInfoBase &targetInfo);

 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter);
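
As a rough usage sketch of the reverted emitPadding declaration above (illustrative, not part of the diff; paddedEnc, smemOffset, and the b op builder are assumed to exist at the call site):

  // Illustrative only: both smemOffset and the returned padding are expressed
  // in elements, so the caller folds the padding into the offset itself.
  Value padding = emitPadding(loc, rewriter, paddedEnc, smemOffset);
  Value paddedOffset = b.add(smemOffset, padding);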

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Lines changed: 1 addition & 1 deletion

@@ -370,7 +370,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   let genVerifyDecl = 1;
 }

-def PaddedSharedEncodingAttr
+def PaddeddSharedEncodingAttr
     : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
                      [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "padded_shared";

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Lines changed: 3 additions & 5 deletions

@@ -201,7 +201,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

    assert(permutedInVals.size() == tileSize * nReps);
    SmallVector<Value> outVals;
-   auto noPaddingOffset = [](Value v) { return v; };
    auto affineOffset = b.i32_val(0);
    auto maskSpanAffineOffset = 0;
    for (int i = 0; i < nReps; ++i) {

@@ -212,13 +211,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
          ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
      // Store
      lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
-                     noPaddingOffset, affineOffset, maskSpanAffineOffset,
-                     rewriter, targetInfo);
+                     affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
      b.barrier();
      // Load
      SmallVector<Value> tileOutVals = lowerLdStShared(
-         loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
-         affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
+         loc, ctx, loadCvt, {}, llvmElemTy, smemBase, affineOffset,
+         maskSpanAffineOffset, rewriter, targetInfo);
      llvm::append_range(outVals, tileOutVals);
    }

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Lines changed: 73 additions & 42 deletions

@@ -12,6 +12,24 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;

+// blocked -> shared.
+// Swizzling in shared memory to avoid bank conflict. Normally used for
+// A/B operands of dots.
+void lowerDistributedToShared(Location loc, Value src, Value dst,
+                              Value adaptorSrc,
+                              const SharedMemoryObject &smemObj,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter,
+                              const TargetInfoBase &targetInfo) {
+  auto srcTy = cast<RankedTensorType>(src.getType());
+  auto dstTy = cast<MemDescType>(dst.getType());
+  auto elemTy = typeConverter->convertType(srcTy.getElementType());
+
+  auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
+  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
+                           targetInfo);
+}
+
 LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                               MemDescType memDescTy, SharedMemoryObject smemObj,
                               ArrayRef<Value> inVals,

@@ -21,25 +39,19 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto regTy = cast<RankedTensorType>(regVal.getType());
   auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());

+  auto regLayout = toLinearLayout(regTy);
+  auto sharedLayout = toLinearLayout(memDescTy);
+  auto cvt = regLayout.invertAndCompose(sharedLayout);
+
+  auto kBlock = str_attr("block");
+  // NYI. We would need to emit a map.shared::cluster instruction.
+  if (!cvt.isTrivialOver({kBlock})) {
+    return failure();
+  }
   auto kReg = str_attr("register");
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
-  auto regLayout = toLinearLayout(regTy);
-  auto paddedLayout =
-      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding());
-  LinearLayout cvt = LinearLayout::empty();
-  if (paddedLayout) {
-    cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
-  } else {
-    auto sharedLayout = toLinearLayout(memDescTy);
-    cvt = regLayout.invertAndCompose(sharedLayout);
-    auto kBlock = str_attr("block");
-    // NYI. We would need to emit a map.shared::cluster instruction.
-    if (!cvt.isTrivialOver({kBlock})) {
-      return failure();
-    }
-  }
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
   lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                  rewriter, targetInfo);

@@ -103,12 +115,25 @@ struct LocalAllocOpConversion
                                                      loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
-      auto *ctx = op.getContext();
-      auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-      if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
-                                 inVals, typeConverter, rewriter,
-                                 targetInfo))) {
-        return failure();
+      // [Legacy local_load/local_store]
+      // TODO(Lezcano) We should activate this path for other targets as it's
+      // more efficient. AFAIK The main blockers are:
+      // - The legacy path calls localLoadOpAnnotation
+      // - The legacy path calls llvm.load/llvm.store unconditionally, while
+      //   the AMD lowering of storeDShared does not, even when the predicate
+      //   is constant true.
+      if (targetInfo.isCuda()) {
+        auto *ctx = op.getContext();
+        auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+        if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
+                                   inVals, typeConverter, rewriter,
+                                   targetInfo))) {
+          return failure();
+        }
+      } else {
+        lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
+                                 adaptor.getSrc(), smemObj, typeConverter,
+                                 rewriter, targetInfo);
       }
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);

@@ -156,31 +181,32 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                          llvmElemTy, rewriter);

-    auto sharedEnc =
-        cast<triton::gpu::SharedEncodingTrait>(memDescTy.getEncoding());
+    // See [Legacy local_load/local_store]
+    if (!targetInfo.isCuda()) {
+      SmallVector<Value> outVals = loadSharedToDistributed(
+          op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
+      Value result =
+          packLLElements(loc, typeConverter, outVals, rewriter, regTy);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    auto regLayout = toLinearLayout(regTy);
+    auto sharedLayout = toLinearLayout(memDescTy);
+    auto cvt = regLayout.invertAndCompose(sharedLayout);
+    auto kBlock = str_attr("block");
+    // NYI. We would need to emit a map.shared::cluster instruction.
+    if (!cvt.isTrivialOver({kBlock})) {
+      return failure();
+    }
     auto kReg = str_attr("register");
     auto kLane = str_attr("lane");
     auto kWarp = str_attr("warp");
     auto kOffset = str_attr("offset");
-    auto regLayout = toLinearLayout(regTy);
-    auto paddedLayout =
-        dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc);
-    LinearLayout cvt = LinearLayout::empty();
-    if (paddedLayout) {
-      cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
-    } else {
-      auto sharedLayout = toLinearLayout(memDescTy);
-      cvt = regLayout.invertAndCompose(sharedLayout);
-      auto kBlock = str_attr("block");
-      // NYI. We would need to emit a map.shared::cluster instruction.
-      if (!cvt.isTrivialOver({kBlock})) {
-        return failure();
-      }
-    }
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});

     auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
-                                  smemObj, rewriter, targetInfo, op);
+                                  smemObj, rewriter, targetInfo);

     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);

@@ -217,9 +243,14 @@ struct LocalStoreOpConversion
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
                                                          llvmElemTy, rewriter);
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
-                               typeConverter, rewriter, targetInfo))) {
-      return failure();
+    if (targetInfo.isCuda()) {
+      if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
+                                 typeConverter, rewriter, targetInfo))) {
+        return failure();
+      }
+    } else {
+      lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
+                               smemObj, typeConverter, rewriter, targetInfo);
     }

     rewriter.eraseOp(op);
