Commit e2ddc50

[BACKEND] Limit vec size with minimum padding interval (triton-lang#8050)
When lowering we need to limit the vector size based on the minimum padding interval, so that a single vectorized access never spans a padding boundary. The old lowering already does this in `emitTransferBetweenRegistersAndShared`.
1 parent 0a93c96 commit e2ddc50
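For context, a padded shared layout such as `#ttg.padded_shared<[4:+4] ...>` inserts extra elements after every fixed-size interval of the logical data, so contiguity in shared memory is only guaranteed within one interval. The sketch below is a minimal, illustrative model of that effect for a single `[interval:+pad]` spec (hypothetical helper, not the Triton implementation); it shows why the minimum interval is an upper bound on the vector width.

```cpp
#include <cstdint>

// Illustrative model of a single-interval padded layout [interval:+pad]:
// after every `interval` logical elements, `pad` unused slots are inserted.
uint32_t paddedOffset(uint32_t logicalIdx, uint32_t interval, uint32_t pad) {
  return logicalIdx + (logicalIdx / interval) * pad;
}

int main() {
  // With [4:+4], elements 0..3 stay contiguous in shared memory ...
  bool contiguous = (paddedOffset(3, 4, 4) - paddedOffset(0, 4, 4)) == 3;
  // ... but element 4 lands 5 slots after element 3, so a vector access
  // wider than 4 elements would straddle the padding gap.
  bool gapAtInterval = (paddedOffset(4, 4, 4) - paddedOffset(3, 4, 4)) == 5;
  return (contiguous && gapAtInterval) ? 0 : 1;
}
```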

3 files changed: 44 additions & 5 deletions

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -569,6 +569,7 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 std::function<Value(Value)> calcPaddedOffset,
                 Value affineOffset, uint64_t maskSpanAffineOffset,
                 RewriterBase &rewriter, const TargetInfoBase &targetInfo,
+                std::optional<int> maybeMaxVecElems = {},
                 Operation *localLoadOp = nullptr);
 
 // Lower an ld/st-like operation given a layout and a callback that creates the
```

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 13 additions & 5 deletions
```diff
@@ -567,7 +567,7 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 std::function<Value(Value)> calcPaddedOffset,
                 Value affineOffset, uint64_t maskSpanAffineOffset,
                 RewriterBase &rewriter, const TargetInfoBase &targetInfo,
-                Operation *localLoadOp) {
+                std::optional<int> maybeMaxVecElems, Operation *localLoadOp) {
 
   bool isStore = !valsArray.empty();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
@@ -593,7 +593,7 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
   auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase,
                    calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId,
-                   warpId, rewriter, targetInfo, {}, emitLdSt);
+                   warpId, rewriter, targetInfo, maybeMaxVecElems, emitLdSt);
 }
 
 SmallVector<Value> lowerLdSt(
@@ -728,9 +728,17 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
   }
   auto affineOffset = smemObj.getShmemOffset(loc, rewriter, srcTy);
   auto maskSpanAffineOffset = smemObj.getMaskSpanOffsets(srcTy);
-  return lowerLdStShared(
-      loc, ctx, cvt, valsArray, llvmElemTy, smemObj.getBase(), calcPaddedOffset,
-      affineOffset, maskSpanAffineOffset, rewriter, targetInfo, localLoadOp);
+
+  std::optional<int> maybeMaxVecElems;
+  if (auto paddedEnc = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
+          srcTy.getEncoding())) {
+    maybeMaxVecElems = paddedEnc.getMinInterval();
+  }
+
+  return lowerLdStShared(loc, ctx, cvt, valsArray, llvmElemTy,
+                         smemObj.getBase(), calcPaddedOffset, affineOffset,
+                         maskSpanAffineOffset, rewriter, targetInfo,
+                         maybeMaxVecElems, localLoadOp);
 }
 
 bool emitTransferBetweenRegistersAndShared(
```
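How the new argument is used inside `lowerLdSt` is not part of this diff; below is a hedged sketch of the intended interaction (illustrative names and shape only, assuming `lowerLdSt` computes its vector width from the linear layout and then applies the optional cap). Callers that pass the default empty optional, i.e. every shared layout other than the padded one, keep their previous vectorization.

```cpp
#include <algorithm>
#include <optional>

// Sketch: pick the number of elements per vectorized ld/st.
// `layoutVec` stands for the contiguity the linear layout alone permits;
// `maybeMaxVecElems` is the minimum padding interval when a padded shared
// encoding is involved, and empty otherwise.
int chooseVecElems(int layoutVec, std::optional<int> maybeMaxVecElems) {
  int vec = layoutVec;
  if (maybeMaxVecElems)
    vec = std::min(vec, *maybeMaxVecElems); // never cross a padding boundary
  return vec;
}
```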

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 30 additions & 0 deletions
```diff
@@ -480,6 +480,36 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #ttg.padded_shared<[4:+4] {offset=[[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0], [2, 0], [4, 0], [8, 0]], block=[]}>
+#smem = #ttg.shared_memory
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: padded_shared_layout_vectorization_limited_by_min_interval
+  tt.func @padded_shared_layout_vectorization_limited_by_min_interval(%arg0: tensor<16x32xf16, #blocked>) {
+    // CHECK-NOT: llvm.store
+    // CHECK: llvm.store {{.*}} : vector<4xf16>
+    // CHECK: llvm.store {{.*}} : vector<4xf16>
+    // CHECK-NOT: llvm.store
+    %0 = ttg.local_alloc %arg0 : (tensor<16x32xf16, #blocked>) -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>
+
+    // CHECK-NOT: llvm.load
+    // CHECK: llvm.load {{.*}} !llvm.ptr<3> -> vector<4xf16>
+    // CHECK: llvm.load {{.*}} !llvm.ptr<3> -> vector<4xf16>
+    // CHECK-NOT: llvm.load
+    %1 = ttg.local_load %0 : !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 16x32> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+
+    // CHECK-NOT: llvm.store
+    // CHECK: llvm.store {{.*}} : vector<4xf16>
+    // CHECK: llvm.store {{.*}} : vector<4xf16>
+    // CHECK-NOT: llvm.store
+    ttg.local_store %1, %0 : tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 // GFX950-LABEL: reduce_32x32
 // GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
```
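Why the test expects exactly two `vector<4xf16>` accesses per transfer: with `sizePerThread = [1, 8]` (and `kWidth = 8` on the dot-operand side) each thread handles 8 contiguous f16 values, which by layout alone could be a single `vector<8xf16>` access, but the padded encoding `[4:+4]` has a minimum interval of 4 elements. A small worked-arithmetic sketch of that reading (my interpretation of the test, not text from the commit):

```cpp
#include <algorithm>
#include <cassert>

int main() {
  int contiguousPerThread = 8; // sizePerThread = [1, 8] / kWidth = 8
  int minPaddingInterval = 4;  // from #ttg.padded_shared<[4:+4] ...>
  int vec = std::min(contiguousPerThread, minPaddingInterval);
  int accessesPerThread = contiguousPerThread / vec;
  assert(vec == 4);               // accesses are vector<4xf16>
  assert(accessesPerThread == 2); // hence the two CHECK lines per transfer
  return 0;
}
```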
