Commit 6edcd49

[AMD] Limit vec size for ds_read_tr + padded layouts by min interval (#8377)
The minimum interval restricts the number of contiguous elements in shared memory for padded layouts, so it also bounds the vector size usable by ds_read_tr.
1 parent 483f9ea commit 6edcd49
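
For context on the description above: a padded shared layout inserts pad elements after every interval elements, so logically consecutive elements stop being physically consecutive at each interval boundary. The standalone C++ sketch below illustrates that bound; the offset mapping offset = i + (i / interval) * pad and all function names are illustrative assumptions, not Triton code.

// Hypothetical sketch (not Triton code): map a logical element index to a
// shared-memory offset under a padded layout and measure how many logically
// consecutive elements stay physically contiguous.
#include <cassert>
#include <cstdint>

// Assumed padded-layout mapping: `pad` extra elements after every `interval`.
int64_t paddedOffset(int64_t i, int64_t interval, int64_t pad) {
  return i + (i / interval) * pad;
}

// Longest run of consecutive indices starting at `start` whose offsets are
// also consecutive; the padding caps this at `interval` elements.
int64_t contiguousRun(int64_t start, int64_t n, int64_t interval, int64_t pad) {
  int64_t run = 1;
  for (int64_t i = start + 1; i < start + n; ++i) {
    if (paddedOffset(i, interval, pad) != paddedOffset(i - 1, interval, pad) + 1)
      break;
    ++run;
  }
  return run;
}

int main() {
  // [512:+16]: runs of 512 elements stay contiguous, so a 4-wide read fits.
  assert(contiguousRun(0, 8, /*interval=*/512, /*pad=*/16) == 8);
  // [1:+4]: padding after every element, nothing wider than 1 is contiguous.
  assert(contiguousRun(0, 8, /*interval=*/1, /*pad=*/4) == 1);
  return 0;
}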

File tree

2 files changed: +27 -0 lines changed

test/Conversion/amd/ds_transpose.mlir

Lines changed: 23 additions & 0 deletions
@@ -5,6 +5,8 @@
 #mma32_scaled = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32, 64], isTransposed = true}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#padding = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [128, 64]}>
+#padding_vec1 = #ttg.padded_shared<[1:+4] {order = [0, 1], shape = [128, 64]}>
 #smem = #ttg.shared_memory

 #linear_ds_tr_tile_out = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [32, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
@@ -688,4 +690,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr_tile_invalid>
     tt.return
   }
+
+  // CHECK-LABEL: ds_transpose_with_padding
+  tt.func @ds_transpose_with_padding(%arg0: !ttg.memdesc<128x64xf16, #padding, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    // CHECK-NOT: rocdl.ds.read.tr16.b64
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_padding_interval_too_small
+  tt.func @ds_transpose_padding_interval_too_small(%arg0: !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-NOT: rocdl.ds.read.tr16.b64
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
 }

third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 4 additions & 0 deletions
@@ -183,6 +183,10 @@ class TransLocalLoadOpConversion
     // Need to have exactly needContigReg, otherwise we can't use ds_read_tr
     auto [elemsPerVec, permutation] =
        largestVectorisation(ctx, cvt, bitwidth, needContigReg);
+
+    if (paddedEnc)
+      elemsPerVec = std::min<int>(elemsPerVec, paddedEnc.getMinInterval());
+
     if (elemsPerVec != needContigReg)
       return failure();
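
Below is a minimal standalone sketch of the check added in this hunk, under assumptions: vectorisation stands in for the result of largestVectorisation(), minInterval for paddedEnc.getMinInterval(), and kNeedContigReg reflects the 4 contiguous f16 elements a 64-bit ds_read_tr16 lane read returns (vector<4xf16> in the test above). It is illustrative, not the actual lowering code.

// Sketch of the added logic: cap the achievable vector width at the layout's
// smallest padding interval and bail out when it no longer matches what
// ds_read_tr needs.
#include <algorithm>
#include <cstdint>
#include <iostream>

// ds_read_tr16.b64 reads 64 bits per lane, i.e. 4 contiguous f16 elements.
constexpr int64_t kNeedContigReg = 4;

// Returns true if the ds_read_tr path still applies after the cap.
bool canUseDsReadTr(int64_t vectorisation, int64_t minInterval) {
  int64_t elemsPerVec = std::min(vectorisation, minInterval);
  return elemsPerVec == kNeedContigReg;
}

int main() {
  // #padding = <[512:+16]>: min interval 512 does not limit the 4-wide read.
  std::cout << canUseDsReadTr(4, 512) << "\n"; // 1 -> ds_read_tr emitted
  // #padding_vec1 = <[1:+4]>: min interval 1 caps the read below 4 elements.
  std::cout << canUseDsReadTr(4, 1) << "\n";   // 0 -> lowering falls back
  return 0;
}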
