
Commit 9410804

alefimov-amd, binarman, and lezcano authored
[Backend] Support padded shared in MemDescSubsliceOp (#7944)
This PR supports padded layout in MemDescSubsliceOp and adds a few related tests.

---------

Co-authored-by: Alexander Efimov <[email protected]>
Co-authored-by: Mario Lezcano Casado <[email protected]>
1 parent f902558 commit 9410804

9 files changed: +137 −42 lines changed

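For orientation, a minimal sketch of the pattern this commit enables, condensed from the lit test added to test/Conversion/amd/tritongpu_to_llvm.mlir below (#blocked, the element type, shapes, and the %src operand are illustrative):

#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [32, 32]}>
#smem = #ttg.shared_memory
// Allocate a padded shared buffer, then subslice it. Before this change the
// MemDescType verifier rejected this combination with "Subslices with padded
// encodings are not yet implemented".
%0 = ttg.local_alloc %src : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
%1 = ttg.memdesc_subslice %0 [16, 0] : !ttg.memdesc<32x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 32x32>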

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 12 additions & 4 deletions
@@ -1158,6 +1158,14 @@ SharedMemoryObject::getMaskSpanOffsets(triton::gpu::MemDescType srcTy) {
   if (allocShape == shape) {
     return 0;
   }
+  if (auto paddedEncoding = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
+          srcTy.getEncoding())) {
+    // Mask is used in fusion of constant part of memory operation address as
+    // immediate operand. Padded layout has additional address computations
+    // between main offset computation and actual memory access, which breaks
+    // constant fusing. Full mask disables this optimization.
+    return ~uint64_t(0);
+  }
   auto totalLl = triton::gpu::toLinearLayout(allocShape, srcTy.getEncoding());
   auto dimNames = standardOutDimNames(ctx, shape.size());
   // Remove the kBlock dimension
@@ -1194,14 +1202,15 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
     return b.i32_val(0);
   }

+  LinearLayout ll;
   // We return the offset without the padding. The padding will be added in the
   // lowering
   if (auto paddedSharedEncoding =
           dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
               srcTy.getEncoding())) {
-    auto allocShape64 = srcTy.getAllocShape();
-    SmallVector<unsigned> allocShape(allocShape64.begin(), allocShape64.end());
-    return LLVM::linearize(rewriter, loc, offsets, allocShape);
+    ll = paddedSharedEncoding.getLinearComponent();
+  } else {
+    ll = triton::gpu::toLinearLayout(srcTy);
   }

   auto dimNames = standardOutDimNames(ctx, offsets.size());
@@ -1210,7 +1219,6 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
     logicalOffsets.push_back({dim, offset});
   }

-  LinearLayout ll = triton::gpu::toLinearLayout(srcTy);
   ll = ll.sublayout({str_attr("offset")}, dimNames);
   auto offset =
       applyLinearLayout(loc, rewriter, ll.invert(), logicalOffsets)[0].second;

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 0 deletions
@@ -1809,6 +1809,12 @@ Attribute PaddedSharedEncodingAttr::parse(AsmParser &parser, Type type) {
     }
   }

+  if (order.size() != shape.size()) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "Mismatch of shape and order ranks in padded layout");
+    return {};
+  }
+
   // Create identity mapping based on shape and order
   auto kOffset = StringAttr::get(parser.getContext(), "offset");
   maybeLL = identityStandardND(kOffset, shape, order);

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 10 additions & 1 deletion
@@ -815,7 +815,16 @@ LogicalResult MemDescSubsliceOp::verify() {
   }

   auto ctx = getContext();
-  auto ll = triton::gpu::toLinearLayout(srcTy);
+  LinearLayout ll;
+  if (auto paddedEncoding = dyn_cast<PaddedSharedEncodingAttr>(srcEnc)) {
+    if (paddedEncoding.getRank() < srcTy.getRank()) {
+      return emitError("SubSlice of low rank PaddedSharedEncoding from higher "
+                       "rank tensors is not supported yet");
+    }
+    ll = paddedEncoding.getLinearComponent();
+  } else {
+    ll = triton::gpu::toLinearLayout(srcTy);
+  }
   // NYI: We don't support non-trivial block dimension for now.
   auto kBlock = mlir::StringAttr::get(getContext(), "block");
   if (ll.getInDimSize(kBlock) != 1) {

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 4 additions & 14 deletions
@@ -177,29 +177,19 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
                          << "the shape size when pipelining.";
     }

-    // Subslices are not yet implemented
-    auto subsliceAllocSize =
-        allocShape.drop_front(allocShape.size() - shape.size());
-    for (auto [allocDim, shapeDim] : llvm::zip(shape, subsliceAllocSize)) {
-      if (allocDim != shapeDim) {
-        return emitError() << "Subslices with padded encodings are not yet "
-                           << "implemented.";
-      }
-    }
-
     // Ensure linear component's outDims match the alloc size ignoring
     // pipelining dimension
     auto outDims = standardOutDimNames(ctx, rank);
     const auto &ll = enc.getLinearComponent();
-    auto expectedShape = shape;
-    if (shape.size() == allocShape.size() && shape.size() == rank + 1)
+    auto expectedShape = allocShape;
+    if (rank == allocShape.size() - 1)
       expectedShape = expectedShape.drop_front(1);

     for (auto d = 0; d < rank; d++) {
       if (ll.getOutDimSize(outDims[d]) != expectedShape[d]) {
         return emitError() << "Mismatch in expected shape for dimension " << d
-                           << ". Expected: " << ll.getOutDimSize(outDims[d])
-                           << ", got: " << expectedShape[d];
+                           << ". Expected: " << expectedShape[d]
+                           << ", got: " << ll.getOutDimSize(outDims[d]);
       }
     }
   }

python/test/gluon/test_core.py

Lines changed: 55 additions & 0 deletions
@@ -897,3 +897,58 @@ def kernel(x, y):

     compiled_kernel = kernel.warmup(input, output, grid=(1, ))
     assert compiled_kernel.asm["ttgir"].count("tt.func private") == 0
+
+
+@pytest.mark.parametrize("interval_pairs", [[[32, 4]], [[16, 4]], [[16, 4], [64, 8]]])
+@pytest.mark.parametrize(
+    "shared_layout",
+    [{"order": [0, 1]}, {"order": [1, 0]},
+     {"offsets": [[0, 1], [0, 2], [0, 8], [0, 4], [0, 16], [0, 32], [2, 0], [1, 0], [4, 0], [8, 0], [16, 0], [32, 0]]}])
+@pytest.mark.parametrize("slice_m_offset, slice_n_offset, slice_m, slice_n", [(48, 16, 16, 16), (32, 48, 32, 16),
+                                                                              (48, 32, 16, 32)])
+def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n):
+    m = 64
+    n = 64
+    num_warps = 1
+    num_warps_cst = ttgl.constexpr(num_warps)
+    warp_size_cst = ttgl.constexpr(THREADS_PER_WARP)
+
+    shape = [m, n]
+    if "order" in shared_layout:
+        order = shared_layout["order"]
+        smem_layout = ttgl.constexpr(ttgl.PaddedSharedLayout.with_identity_for(interval_pairs, shape, order))
+    elif "offsets" in shared_layout:
+        offsets = shared_layout["offsets"]
+        blocks = []
+        smem_layout = ttgl.constexpr(ttgl.PaddedSharedLayout(interval_pairs, offsets, blocks, shape))
+
+    @gluon.jit
+    def kernel(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, SLICE_M_OFFSET: ttgl.constexpr,
+               SLICE_N_OFFSET: ttgl.constexpr, SLICE_M: ttgl.constexpr, SLICE_N: ttgl.constexpr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [warp_size_cst, 1], [1, num_warps_cst], [1, 0])
+        offs_m_load = ttgl.arange(0, M, ttgl.SliceLayout(1, blocked))
+        offs_n_load = ttgl.arange(0, N, ttgl.SliceLayout(0, blocked))
+        in_offs = offs_m_load[:, None] * N + offs_n_load[None, :]
+
+        in_data = ttgl.load(in_ptr + in_offs)
+
+        smem = ttgl.allocate_shared_memory(ttgl.int32, [M, N], smem_layout)
+        smem_slice0 = smem.slice(SLICE_M_OFFSET, SLICE_M, dim=0)
+        smem_slice1 = smem_slice0.slice(SLICE_N_OFFSET, SLICE_N, dim=1)
+
+        smem.store(in_data)
+
+        out_data = smem_slice1.load(blocked)
+
+        offs_m_store = ttgl.arange(0, SLICE_M, ttgl.SliceLayout(1, blocked))
+        offs_n_store = ttgl.arange(0, SLICE_N, ttgl.SliceLayout(0, blocked))
+        out_offs = offs_m_store[:, None] * SLICE_N + offs_n_store[None, :]
+        ttgl.store(out_ptr + out_offs, out_data)
+
+    input = torch.arange(m * n, device="cuda").reshape(m, n).to(torch.int32)
+    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device="cuda")
+    ref_output = input[slice_m_offset:slice_m_offset + slice_m, slice_n_offset:slice_n_offset + slice_n]
+
+    kernel[(1, )](input, output, m, n, slice_m_offset, slice_n_offset, slice_m, slice_n, num_warps=num_warps)
+
+    assert (output == ref_output).all()
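Assuming the repository's usual pytest workflow, the new case can be run in isolation with something like: pytest python/test/gluon/test_core.py -k test_padded_shared_layout_subslice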

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 24 additions & 0 deletions
@@ -510,6 +510,30 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n

 // -----

+// CHECK-LABEL: padded_shared_layout_subslice_load_store
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
+#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [32, 32]}>
+#smem = #ttg.shared_memory
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 1], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @padded_shared_layout_subslice_load_store(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK: llvm.store {{.*}} : vector<8xf16>, !llvm.ptr<3>
+    // CHECK-NOT: llvm.store
+    %0 = ttg.local_alloc %arg0 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
+    %1 = ttg.memdesc_subslice %0 [16, 0] : !ttg.memdesc<32x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 32x32>
+    // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16>
+    // CHECK-NOT: llvm.load
+    %2 = ttg.local_load %1 : !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 32x32> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    // CHECK-COUNT-2: llvm.store {{.*}} : vector<4xf16>, !llvm.ptr<3>
+    // CHECK-NOT: llvm.store
+    ttg.local_store %2, %1 : tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable, 32x32>
+    tt.return
+  }
+}
+
+// -----
+
 // GFX950-LABEL: reduce_32x32
 // GFX950: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap"
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {

test/TritonGPU/invalid.mlir

Lines changed: 16 additions & 7 deletions
@@ -452,20 +452,29 @@ tt.func @async_copy_invalid_other_type(%input: tensor<64x64x!tt.ptr<f16>, #block
 // -----

 #shared = #ttg.padded_shared<[4:+4] {offset=[[1, 0], [2, 0], [0, 1], [0, 2]], block=[]}>
-// expected-error @below {{Subslices with padded encodings are not yet implemented}}
-!unsupported_subslice = !ttg.memdesc<2x2xf32, #shared, #ttg.shared_memory, 4x4>
+// expected-error @below {{Mismatch in expected shape for dimension 0. Expected: 2, got: 4}}
+!out_dim_too_small = !ttg.memdesc<2x2xf32, #shared, #ttg.shared_memory>

 // -----

 #shared = #ttg.padded_shared<[4:+4] {offset=[[1, 0], [2, 0], [0, 1], [0, 2]], block=[]}>
-// expected-error @below {{Mismatch in expected shape for dimension 0. Expected: 4, got: 2}}
-!out_dim_too_small = !ttg.memdesc<2x2xf32, #shared, #ttg.shared_memory>
+// expected-error @below {{Mismatch in expected shape for dimension 0. Expected: 8, got: 4}}
+!out_dim_too_large = !ttg.memdesc<8x8xf32, #shared, #ttg.shared_memory>

 // -----

-#shared = #ttg.padded_shared<[4:+4] {offset=[[1, 0], [2, 0], [0, 1], [0, 2]], block=[]}>
-// expected-error @below {{Mismatch in expected shape for dimension 0. Expected: 4, got: 8}}
-!out_dim_too_large = !ttg.memdesc<8x8xf32, #shared, #ttg.shared_memory>
+// expected-error @below {{Mismatch of shape and order ranks in padded layout}}
+#shared = #ttg.padded_shared<[4:+4] {shape=[1, 2, 4], order=[1, 0]}>
+
+// -----
+
+#shared = #ttg.padded_shared<[4:+4] {shape=[32, 32], order=[1, 0]}>
+#smem = #ttg.shared_memory
+tt.func public @padded_subview_unsupported_size(%arg0: !ttg.memdesc<2x32x32xf32, #shared, #smem>) {
+  // expected-error @+1 {{SubSlice of low rank PaddedSharedEncoding from higher rank tensors is not supported yet}}
+  %a = ttg.memdesc_subslice %arg0 [0, 16, 0] : !ttg.memdesc<2x32x32xf32, #shared, #smem> -> !ttg.memdesc<2x16x32xf32, #shared, #smem, 2x32x32>
+  tt.return
+}

 // -----

test/TritonGPU/memdesc-subview-split.mlir

Lines changed: 7 additions & 16 deletions
@@ -2,6 +2,7 @@


 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [1, 0]}>
+#padded = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [256, 128]}>
 #smem = #ttg.shared_memory

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
@@ -10,30 +11,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     %c0_i32 = arith.constant 0 : i32
     %0 = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf16, #shared, #smem, mutable>
     %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<1x256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable>
-    %c0_i32_0 = arith.constant 0 : i32
-    %c0_i32_1 = arith.constant 0 : i32
     %2 = ttg.memdesc_subslice %1 [0, 0] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
-    %c0_i32_2 = arith.constant 0 : i32
-    %c32_i32 = arith.constant 32 : i32
     %3 = ttg.memdesc_subslice %1 [0, 32] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
-    %c0_i32_3 = arith.constant 0 : i32
-    %c64_i32 = arith.constant 64 : i32
     %4 = ttg.memdesc_subslice %1 [0, 64] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
-    %c0_i32_4 = arith.constant 0 : i32
-    %c96_i32 = arith.constant 96 : i32
     %5 = ttg.memdesc_subslice %1 [0, 96] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
-    %c128_i32 = arith.constant 128 : i32
-    %c0_i32_5 = arith.constant 0 : i32
     %6 = ttg.memdesc_subslice %1 [128, 0] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
-    %c128_i32_6 = arith.constant 128 : i32
-    %c32_i32_7 = arith.constant 32 : i32
     %7 = ttg.memdesc_subslice %1 [128, 32] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
-    %c128_i32_8 = arith.constant 128 : i32
-    %c64_i32_9 = arith.constant 64 : i32
     %8 = ttg.memdesc_subslice %1 [128, 64] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
-    %c128_i32_10 = arith.constant 128 : i32
-    %c96_i32_11 = arith.constant 96 : i32
     %9 = ttg.memdesc_subslice %1 [128, 96] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128>
+
+    %padded = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf16, #padded, #smem, mutable>
+    %padded_indexed_explicit_alloc_shape = ttg.memdesc_index %padded[%c0_i32] : !ttg.memdesc<1x256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<256x128xf16, #padded, #smem, mutable, 1x256x128>
+    %10 = ttg.memdesc_subslice %padded_indexed_explicit_alloc_shape [128, 96] : !ttg.memdesc<256x128xf16, #padded, #smem, mutable, 1x256x128> -> !ttg.memdesc<128x32xf16, #padded, #smem, mutable, 1x256x128>
+    %padded_indexed_implicit_alloc_shape = ttg.memdesc_index %padded[%c0_i32] : !ttg.memdesc<1x256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<256x128xf16, #padded, #smem, mutable>
+    %11 = ttg.memdesc_subslice %padded_indexed_implicit_alloc_shape [128, 96] : !ttg.memdesc<256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<128x32xf16, #padded, #smem, mutable, 256x128>
     tt.return
   }
 }

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,9 @@ LogicalResult lowerLdStMatrix(
       return result;
     }
   }
+  if (isa<PaddedSharedEncodingAttr>(memDescType.getEncoding())) {
+    return failure();
+  }
   auto memLayout = toLinearLayout(memDescType);
   auto cvt = regLayout.invertAndCompose(memLayout);
   auto kBlock = StringAttr::get(loc.getContext(), "block");
