Skip to content

Commit 6a38bee

Browse files
authored
[BACKEND] Fix subview padding for PaddedSharedEncoding (#7404)
For padded layouts introduced by triton-lang/triton#7212 we need to add the padding to the base pointer of the resulting subview.
1 parent 7a4cfe7 commit 6a38bee

File tree

4 files changed

+62
-7
lines changed

4 files changed

+62
-7
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,11 @@ SmallVector<SmallVector<Value>>
516516
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
517517
Attribute layout, RankedTensorType type, bool withCTAOffset);
518518

519+
// Emits the required padding in elements for the given shared memory offset
520+
Value emitPadding(Location loc, RewriterBase &rewriter,
521+
triton::gpu::PaddedSharedEncodingAttr layout,
522+
Value smemOffset);
523+
519524
// Emits IR to load data from shared memory into registers, or to store data
520525
// from registers into shared memory.
521526
//

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,21 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
398398
return ret;
399399
}
400400

401+
Value emitPadding(Location loc, RewriterBase &rewriter,
402+
triton::gpu::PaddedSharedEncodingAttr layout,
403+
Value smemOffset) {
404+
TritonLLVMOpBuilder b(loc, rewriter);
405+
406+
Value padOffset = b.i32_val(0);
407+
for (auto [interval, padding] :
408+
llvm::zip_equal(layout.getIntervals(), layout.getPaddings())) {
409+
Value iVal = b.i32_val(llvm::Log2_32(interval));
410+
Value pVal = b.i32_val(llvm::Log2_32(padding));
411+
padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
412+
}
413+
return padOffset;
414+
}
415+
401416
namespace {
402417

403418
Value getSmemVecAddr(const LinearLayout &regLayout,
@@ -488,13 +503,7 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
488503
if (auto paddedLayout =
489504
dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc)) {
490505
// Apply the offset needed for padding.
491-
Value padOffset = b.i32_val(0);
492-
for (auto [interval, padding] : llvm::zip_equal(
493-
paddedLayout.getIntervals(), paddedLayout.getPaddings())) {
494-
Value iVal = b.i32_val(llvm::Log2_32(interval));
495-
Value pVal = b.i32_val(llvm::Log2_32(padding));
496-
padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
497-
}
506+
Value padOffset = emitPadding(loc, rewriter, paddedLayout, smemOffset);
498507
smemOffset = b.add(smemOffset, padOffset);
499508
}
500509
} else { // Case 2 -> rank-reduced swizzling

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,13 @@ struct MemDescSubviewOpConversion
513513
.second;
514514
}
515515

516+
if (auto paddedLayout = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
517+
srcTy.getEncoding())) {
518+
// Apply padding based on the computed offset
519+
Value padOffset = emitPadding(loc, rewriter, paddedLayout, offset);
520+
offset = b.add(offset, padOffset);
521+
}
522+
516523
auto base = smemObj.getBase();
517524
auto elemPtrTy = base.getType();
518525
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,40 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
413413

414414
// -----
415415

// CHECK-LABEL: padded_shared_layout_subview
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared = #ttg.padded_shared<[128:+4, 256:+8] {order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
  tt.func @padded_shared_layout_subview(%arg0: !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable>) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    // Skip two constants from the stride calculation

    // CHECK-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32)
    // CHECK-DAG: %[[CST2:.+]] = llvm.mlir.constant(2 : i32)
    // CHECK-DAG: %[[CST7:.+]] = llvm.mlir.constant(7 : i32)
    // CHECK-DAG: %[[CST8:.+]] = llvm.mlir.constant(8 : i32)
    // CHECK-DAG: %[[CST3:.+]] = llvm.mlir.constant(3 : i32)

    // CHECK: %[[SHR0:.+]] = llvm.ashr %[[XOR:.+]], %[[CST7]] : i32
    // CHECK-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST2]] : i32
    // CHECK-NEXT: %[[ADD0:.+]] = llvm.add %[[SHL0]], %[[CST0]] : i32
    // CHECK-NEXT: %[[SHR1:.+]] = llvm.ashr %[[XOR]], %[[CST8]] : i32
    // CHECK-NEXT: %[[SHL1:.+]] = llvm.shl %[[SHR1]], %[[CST3]] : i32
    // CHECK-NEXT: %[[ADD1:.+]] = llvm.add %[[ADD0]], %[[SHL1]] : i32
    // CHECK-NEXT: %[[ADD2:.+]] = llvm.add %[[XOR]], %[[ADD1]] : i32
    // CHECK-NEXT: llvm.getelementptr %{{.+}}[%[[ADD2]]]

    %1 = ttg.memdesc_subview %arg0[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
    %2 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
    ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
    tt.return
  }
}
447+
448+
// -----
449+
416450
// GFX950-LABEL: reduce_32x32
417451
// GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
418452
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {

0 commit comments

Comments
 (0)