Skip to content

Commit ea4bdaf

Browse files
authored
[BACKEND] Apply padding when lowering ttg.memdesc_index (#7696)
triton-lang/triton#7622 introduced `ttg.memdesc_index`, which applies a constant offset to the base pointer of the smem object. For padded layouts, we need to add padding based on that offset, similar to what triton-lang/triton#7404 did for the old subview operation. I also adjusted the lit test to check that we actually generate padding from `ttg.memdesc_index`. The previous version of the test did not fail because its checks also matched the lowering of `ttg.local_load`/`ttg.local_store`.
1 parent 1f797a5 commit ea4bdaf

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,15 @@ struct MemDescIndexOpConversion
493493
auto prevOffsets = smemObj.getOffsets();
494494
SmallVector<Value> offsetVals(prevOffsets.end() - dstTy.getRank(),
495495
prevOffsets.end());
496+
497+
// Apply padding based on the amount we move the base ptr
498+
if (auto padEnc = dyn_cast<PaddedSharedEncodingAttr>(dstTy.getEncoding())) {
499+
auto bitwidth = dstTy.getElementTypeBitWidth();
500+
Value padOffset = emitPadding(loc, rewriter, padEnc, bitwidth, offset,
501+
/*offsetInBytes=*/false);
502+
offset = b.add(offset, padOffset);
503+
}
504+
496505
// Advance the pointer and keep the opOffsets as the new shape
497506
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
498507
llvmElemTy, offsetVals);

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -412,34 +412,30 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
412412

413413
// -----
414414

415-
// CHECK-LABEL: padded_shared_layout_subview
415+
// GFX950-LABEL: padded_shared_layout_subview
416416
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
417-
#shared = #ttg.padded_shared<[128:+4, 256:+8] {order = [1, 0]}>
417+
#shared = #ttg.padded_shared<[128:+4] {order = [1, 0]}>
418418
#smem = #ttg.shared_memory
419419
module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
420420
tt.func @padded_shared_layout_subview(%arg0: !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable>) {
421421
%c0_i32 = arith.constant 0 : i32
422422
%c1_i32 = arith.constant 1 : i32
423-
// Skip two constants from the stride calculation
423+
// Skip three constants from the stride calculation
424+
// GFX950: llvm.mlir.constant
425+
// GFX950: llvm.mlir.constant
426+
// GFX950: llvm.mlir.constant
424427

425-
// CHECK-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32)
426-
// CHECK-DAG: %[[CST3:.+]] = llvm.mlir.constant(3 : i32)
427-
// CHECK-DAG: %[[CST4:.+]] = llvm.mlir.constant(4 : i32)
428-
// CHECK-DAG: %[[CST8:.+]] = llvm.mlir.constant(8 : i32)
429-
// CHECK-DAG: %[[CST9:.+]] = llvm.mlir.constant(9 : i32)
428+
// GFX950-DAG: %[[CST0:.+]] = llvm.mlir.constant(0 : i32)
429+
// GFX950-DAG: %[[CST7:.+]] = llvm.mlir.constant(7 : i32)
430+
// GFX950-DAG: %[[CST2:.+]] = llvm.mlir.constant(2 : i32)
430431

431-
// CHECK: %[[SHR0:.+]] = llvm.ashr %[[ADD:.+]], %[[CST8]] : i32
432-
// CHECK-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST3]] : i32
433-
// CHECK-NEXT: %[[ADD0:.+]] = llvm.add %[[SHL0]], %[[CST0]] : i32
434-
// CHECK-NEXT: %[[SHR1:.+]] = llvm.ashr %[[ADD]], %[[CST9]] : i32
435-
// CHECK-NEXT: %[[SHL1:.+]] = llvm.shl %[[SHR1]], %[[CST4]] : i32
436-
// CHECK-NEXT: %[[ADD1:.+]] = llvm.add %[[ADD0]], %[[SHL1]] : i32
437-
// CHECK-NEXT: %[[ADD2:.+]] = llvm.add %[[ADD]], %[[ADD1]] : i32
438-
// CHECK: llvm.getelementptr inbounds %{{.+}}[%[[ADD2]]]
432+
// GFX950: %[[SHR0:.+]] = llvm.ashr %[[ADD:.+]], %[[CST7]] : i32
433+
// GFX950-NEXT: %[[SHL0:.+]] = llvm.shl %[[SHR0]], %[[CST2]] : i32
434+
// GFX950-NEXT: %[[ADD1:.+]] = llvm.add %[[CST0]], %[[SHL0]] : i32
435+
// GFX950-NEXT: %[[ADD2:.+]] = llvm.add %[[ADD]], %[[ADD1]] : i32
436+
// GFX950: llvm.getelementptr %{{.+}}[%[[ADD2]]]
439437

440438
%1 = ttg.memdesc_index %arg0, %c1_i32 : !ttg.memdesc<2x64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
441-
%2 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
442-
ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
443439
tt.return
444440
}
445441
}

0 commit comments

Comments
 (0)