Skip to content

Commit dcf41f6

Browse files
authored
[BACKEND] Share ld/st.shared lowering between convert_layout and local_load/store (#7248)
We recently heavily optimised the ld/st.shared lowering for convert_layout. Let's reuse it in other ops for better perf and to have just one source of truth. For now we just activate it for NVIDIA's path. Fixing it for AMD should be easy (and beneficial) but I don't have access to an AMD GPU so... cc @antiagainst
1 parent 40f7163 commit dcf41f6

File tree

4 files changed

+148
-34
lines changed

4 files changed

+148
-34
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,15 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
574574
ConversionPatternRewriter &rewriter,
575575
const TargetInfoBase &targetInfo);
576576

577+
// Lower local_load/local_store via ld.shared/st.shared
578+
SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
579+
// Map from registers to offset
580+
LinearLayout cvt, ArrayRef<Value> valsArray,
581+
// Input for store, output for load
582+
Type llvmElemTy, Value smemBase,
583+
ConversionPatternRewriter &rewriter,
584+
const TargetInfoBase &targetInfo);
585+
577586
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
578587
RewriterBase &rewriter);
579588

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 108 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,36 @@ void lowerDistributedToShared(
2929
targetInfo, llvmOpCount);
3030
}
3131

32+
// Lowers a store of `inVals` (register values of `regVal`) into the shared
// memory described by `memDescTy`/`smemObj` using the shared ld/st lowering.
// Returns failure() when the register->offset conversion is non-trivial over
// the "block" dimension (cross-CTA traffic), which is not yet implemented.
LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                              MemDescType memDescTy, SharedMemoryObject smemObj,
                              ArrayRef<Value> inVals,
                              const LLVMTypeConverter *typeConverter,
                              ConversionPatternRewriter &rewriter,
                              const TargetInfoBase &targetInfo) {
  auto regTy = cast<RankedTensorType>(regVal.getType());
  auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());

  // Map registers onto shared-memory offsets: reg layout composed with the
  // inverse of the shared layout.
  auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
  auto sharedLayout =
      toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
  auto cvt = regLayout.invertAndCompose(sharedLayout);

  auto kBlock = str_attr("block");
  // NYI. We would need to emit a map.shared::cluster instruction.
  if (!cvt.isTrivialOver({kBlock})) {
    return failure();
  }
  auto kReg = str_attr("register");
  auto kLane = str_attr("lane");
  auto kWarp = str_attr("warp");
  auto kOffset = str_attr("offset");
  // Drop the trivial block dimension; keep {register, lane, warp} -> offset.
  cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
  lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, smemObj.getBase(), rewriter,
                 targetInfo);

  return success();
}
61+
3262
struct GlobalScratchAllocOpConversion
3363
: public ConvertOpToLLVMPattern<triton::gpu::GlobalScratchAllocOp> {
3464
const TargetInfoBase *targetInfo;
@@ -77,17 +107,34 @@ struct LocalAllocOpConversion
77107
Location loc = op->getLoc();
78108
Value smemBase =
79109
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
80-
auto resultTy = cast<MemDescType>(op.getType());
110+
auto memDescTy = cast<MemDescType>(op.getType());
81111
auto typeConverter = getTypeConverter();
82112

83-
auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
84-
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(),
113+
auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
114+
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, memDescTy.getRank(),
85115
loc, rewriter);
86116
// If there is an initial tensor, store it into the shared memory.
87117
if (op.getSrc()) {
88-
lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
89-
adaptor.getSrc(), smemObj, typeConverter,
90-
rewriter, targetInfo);
118+
// [Legacy local_load/local_store]
119+
// TODO(Lezcano) We should activate this path for other targets as it's
120+
// more efficient. AFAIK The main blockers are:
121+
// - The legacy path calls localLoadOpAnnotation
122+
// - The legacy path calls llvm.load/llvm.store unconditionally, while
123+
// the AMD lowering of storeDShared does not, even when the predicate
124+
// is constant true.
125+
if (targetInfo.isCuda()) {
126+
auto *ctx = op.getContext();
127+
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
128+
if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
129+
inVals, typeConverter, rewriter,
130+
targetInfo))) {
131+
return failure();
132+
}
133+
} else {
134+
lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
135+
adaptor.getSrc(), smemObj, typeConverter,
136+
rewriter, targetInfo);
137+
}
91138
}
92139
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
93140
rewriter.replaceOp(op, retVal);
@@ -122,27 +169,48 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
122169
LogicalResult
123170
matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
124171
ConversionPatternRewriter &rewriter) const override {
125-
return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
126-
}
127-
128-
private:
129-
LogicalResult
130-
lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
131-
const LLVMTypeConverter *typeConverter,
132-
ConversionPatternRewriter &rewriter) const {
133172
auto loc = op.getLoc();
134-
auto srcTy = op.getSrc().getType();
135-
auto dstTy = op.getResult().getType();
173+
auto *ctx = op.getContext();
174+
auto memDescVal = op.getSrc();
175+
auto regVal = op.getResult();
176+
auto memDescTy = cast<MemDescType>(memDescVal.getType());
177+
auto regTy = cast<RankedTensorType>(regVal.getType());
178+
auto typeConverter = getTypeConverter();
136179

137180
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
138181
loc, adaptor.getSrc(),
139-
typeConverter->convertType(srcTy.getElementType()), rewriter);
140-
auto elemLlvmTy = typeConverter->convertType(dstTy.getElementType());
182+
typeConverter->convertType(memDescTy.getElementType()), rewriter);
183+
auto llvmElemTy = typeConverter->convertType(regTy.getElementType());
184+
185+
// See [Legacy local_load/local_store]
186+
if (!targetInfo.isCuda()) {
187+
SmallVector<Value> outVals = loadSharedToDistributed(
188+
op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
189+
Value result =
190+
packLLElements(loc, typeConverter, outVals, rewriter, regTy);
191+
rewriter.replaceOp(op, result);
192+
return success();
193+
}
141194

142-
SmallVector<Value> outVals = loadSharedToDistributed(
143-
op, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
195+
auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
196+
auto sharedLayout =
197+
toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
198+
auto cvt = regLayout.invertAndCompose(sharedLayout);
199+
auto kBlock = str_attr("block");
200+
// NYI. We would need to emit a map.shared::cluster instruction.
201+
if (!cvt.isTrivialOver({kBlock})) {
202+
return failure();
203+
}
204+
auto kReg = str_attr("register");
205+
auto kLane = str_attr("lane");
206+
auto kWarp = str_attr("warp");
207+
auto kOffset = str_attr("offset");
208+
cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
209+
210+
auto outVals = lowerLocalLdSt(op.getLoc(), ctx, cvt, {}, llvmElemTy,
211+
smemObj.getBase(), rewriter, targetInfo);
144212

145-
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
213+
Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
146214
rewriter.replaceOp(op, result);
147215

148216
return success();
@@ -167,20 +235,30 @@ struct LocalStoreOpConversion
167235
LogicalResult
168236
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
169237
ConversionPatternRewriter &rewriter) const override {
238+
auto loc = op.getLoc();
239+
auto *ctx = op.getContext();
240+
Value regVal = op.getSrc();
170241
Value memDescVal = op.getDst();
171-
auto llvmElemTy =
172-
getTypeConverter()->convertType(op.getDst().getType().getElementType());
173-
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
174-
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
175-
242+
auto typeConverter = getTypeConverter();
243+
auto memDescTy = cast<MemDescType>(memDescVal.getType());
244+
auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
245+
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
246+
llvmElemTy, rewriter);
247+
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
176248
std::pair<size_t, Type> llvmOpCount;
177-
lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
178-
adaptor.getSrc(), smemObj, getTypeConverter(),
179-
rewriter, targetInfo, &llvmOpCount);
249+
if (targetInfo.isCuda()) {
250+
if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
251+
typeConverter, rewriter, targetInfo))) {
252+
return failure();
253+
}
254+
} else {
255+
lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
256+
smemObj, typeConverter, rewriter, targetInfo,
257+
&llvmOpCount);
258+
}
180259

181260
targetInfo.localStoreOpAnnotation(op, llvmOpCount.first,
182261
llvmOpCount.second);
183-
184262
rewriter.eraseOp(op);
185263
return success();
186264
}

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,35 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
588588
return outVals;
589589
}
590590

591+
// Lowers a local_load (valsArray empty) or local_store (valsArray holds the
// values to store) through the shared ld/st.shared lowering. Register
// broadcasting is stripped before the transfer and, for loads, re-applied to
// the results so the output matches the original register layout.
SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
                                  LinearLayout cvt, ArrayRef<Value> valsArray,
                                  // Input for store, output for load
                                  Type llvmElemTy, Value smemBase,
                                  ConversionPatternRewriter &rewriter,
                                  const TargetInfoBase &targetInfo) {
  assert(cvt.getNumOutDims() == 1);
  assert(*cvt.getOutDimNames().begin() == str_attr("offset"));
  auto isStore = !valsArray.empty();
  // Remove broadcasting in the registers
  auto removeBroadcastSrc = actionRemoveBroadcastedRegs(cvt);
  if (!removeBroadcastSrc.isIdentity()) {
    auto prmtCvt = removeBroadcastSrc.apply(cvt);
    auto inVals = to_vector(valsArray);
    if (isStore) {
      // Drop the duplicated register values before storing.
      inVals = removeBroadcastSrc.apply(inVals);
    }
    auto outVals = lowerLdStShared(loc, ctx, prmtCvt, inVals, llvmElemTy,
                                   smemBase, rewriter, targetInfo);
    if (!isStore) {
      // Re-broadcast the loaded values to the original register layout.
      outVals = broadcastAs(outVals, cvt);
    }
    return outVals;
  }

  return lowerLdStShared(loc, ctx, cvt, valsArray, llvmElemTy, smemBase,
                         rewriter, targetInfo);
}
619+
591620
bool emitTransferBetweenRegistersAndShared(
592621
LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
593622
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,7 +1918,7 @@ module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-
19181918
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
19191919
// CHECK-LABEL: @vectorize_shmem_load
19201920
// CHECK: llvm.load
1921-
// CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8>
1921+
// CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<2xi32>
19221922
// CHECK-NOT: llvm.load
19231923
tt.func public @vectorize_shmem_load(%shmem : !ttg.memdesc<16x16xi8, #shared, #smem>) {
19241924
%0 = ttg.local_load %shmem : !ttg.memdesc<16x16xi8, #shared, #smem> -> tensor<16x16xi8, #blocked>
@@ -1933,9 +1933,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
19331933
#smem = #ttg.shared_memory
19341934
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
19351935
// CHECK-LABEL: @vectorize_shmem_store
1936-
// CHECK: llvm.store
1937-
// CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3>
1938-
// CHECK-NOT: llvm.store
1936+
// CHECK-COUNT-4: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<3>
19391937
tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) {
19401938
%0 = ttg.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !ttg.memdesc<64x64xi32, #shared, #smem>
19411939
tt.return

0 commit comments

Comments
 (0)