
Commit 8b792c8

[AMD] Redesign stream pipeliner LDS layout selection logic (#8053)
This commit adapts the LDS layout selection logic in the stream pipeliner so that we pick a common swizzled shared memory layout whose vecSize equals the maximum kWidth across all users.
1 parent fb68aea commit 8b792c8
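The selection rule described in the commit message can be summarized with a small sketch. This is a minimal, hypothetical illustration in plain C++, not the actual pass code: the `LdsUser` struct and `pickCommonVecSize` helper are made-up names for this example only; in the real pass the kWidth values come from the `#ttg.dot_op` layouts of the shared-memory users.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical description of one consumer of the pipelined load: a tt.dot
// operand (reached directly or through tt.trans) with its own kWidth.
struct LdsUser {
  unsigned kWidth; // contiguous elements per thread along the K dimension
};

// Sketch of the rule: all users read from the same LDS buffer, so the common
// swizzled layout uses vecSize = max kWidth across users. For the tests below
// (users with kWidth = 4 and kWidth = 8) this yields vec = 8.
unsigned pickCommonVecSize(const std::vector<LdsUser> &users) {
  unsigned vecSize = 1;
  for (const LdsUser &user : users)
    vecSize = std::max(vecSize, user.kWidth);
  return vecSize;
}
```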

File tree

7 files changed: +180 −63 lines changed

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2" -canonicalize | FileCheck %s
+
+// Pick a common shared memory layout with vec = max kWidth of all users.
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [0, 1]}>
+// CHECK-NOT: #ttg.swizzled_shared
+// CHECK{LITERAL}: #smem = #ttg.shared_memory
+// CHECK-LABEL: test_lds_layout_selection
+
+// CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
+// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]
+
+// CHECK: scf.for {{.+}} iter_args({{.*}}, %[[MEMDESC_IDX_ITER:.+]] = %[[MEMDESC_IDX]]) -> ({{.+}})
+// CHECK: %[[LOAD:.+]] = tt.load {{.+}} : tensor<64x16x!tt.ptr<f16>, #blocked>
+// CHECK: %[[LOCAL_LOAD_TRANS:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #linear>
+// CHECK: %[[LOCAL_LOAD_DIRECT:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+// CHECK: tt.dot {{.+}}, %[[LOCAL_LOAD_DIRECT]], {{.+}}
+// CHECK: %[[TRANS:.+]] = tt.trans %[[LOCAL_LOAD_TRANS]] {{.+}} : {{.+}} -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
+// CHECK: tt.dot {{.+}}, %[[TRANS]], {{.+}}
+// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]
+// CHECK: ttg.local_store %[[LOAD]], %[[MEMDESC_IDX]]
+// CHECK: scf.yield
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @test_lds_layout_selection(
+      %arg0: tensor<64x16x!tt.ptr<f16>, #blocked>,
+      %out0 : tensor<128x16x!tt.ptr<f32>, #blocked>,
+      %out1 : tensor<128x64x!tt.ptr<f32>, #blocked>
+  ) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
+    %cst_1 = arith.constant dense<0.693147182> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+    %cst_2 = arith.constant dense<0.581374812> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c8_i32 = arith.constant 8 : i32
+
+    %0:2 = scf.for %arg1 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg2 = %cst_0, %arg3 = %cst_3) -> (tensor<128x16xf32, #mma1>, tensor<128x64xf32, #mma>) : i32 {
+      %1 = tt.load %arg0 : tensor<64x16x!tt.ptr<f16>, #blocked>
+      %2 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #linear>
+      %3 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
+      %4 = tt.dot %cst_1, %3, %arg2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1>
+      %5 = tt.trans %2 {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+      %6 = tt.dot %cst_2, %5, %arg3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x64xf32, #mma>
+      scf.yield %4, %6 : tensor<128x16xf32, #mma1>, tensor<128x64xf32, #mma>
+    }
+
+    %7 = ttg.convert_layout %0#0 : tensor<128x16xf32, #mma1> -> tensor<128x16xf32, #blocked>
+    %8 = ttg.convert_layout %0#1 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
+    tt.store %out0, %7 : tensor<128x16x!tt.ptr<f32>, #blocked>
+    tt.store %out1, %8 : tensor<128x64x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+// -----
+
+// Verify that a common shared memory layout is chosen for users with different kWidth and opIdx.
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [0, 1]}>
+// CHECK-NOT: #ttg.swizzled_shared
+// CHECK{LITERAL}: #smem = #ttg.shared_memory
+// CHECK-LABEL: test_lds_layout_selection_different_opIdx
+
+// CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
+// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]
+
+// CHECK: scf.for {{.+}} iter_args({{.*}}, %[[MEMDESC_IDX_ITER:.+]] = %[[MEMDESC_IDX]]) -> ({{.+}})
+// CHECK: %[[LOAD:.+]] = tt.load {{.+}} : tensor<64x16x!tt.ptr<f16>, #blocked>
+// CHECK: %[[LOCAL_LOAD_TRANS:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #linear>
+// CHECK: %[[LOCAL_LOAD_DIRECT:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+// CHECK: tt.dot %[[LOCAL_LOAD_DIRECT]], {{.+}}
+// CHECK: %[[TRANS:.+]] = tt.trans %[[LOCAL_LOAD_TRANS]] {{.+}} : {{.+}} -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
+// CHECK: tt.dot {{.+}}, %[[TRANS]], {{.+}}
+// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]
+// CHECK: ttg.local_store %[[LOAD]], %[[MEMDESC_IDX]]
+// CHECK: scf.yield
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @test_lds_layout_selection_different_opIdx(
+      %arg0: tensor<64x16x!tt.ptr<f16>, #blocked>,
+      %out0 : tensor<64x64x!tt.ptr<f32>, #blocked>,
+      %out1 : tensor<128x64x!tt.ptr<f32>, #blocked>
+  ) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma1>
+    %cst_1 = arith.constant dense<0.693147182> : tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
+    %cst_2 = arith.constant dense<0.581374812> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c8_i32 = arith.constant 8 : i32
+
+    %0:2 = scf.for %arg1 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg2 = %cst_0, %arg3 = %cst_3) -> (tensor<64x64xf32, #mma1>, tensor<128x64xf32, #mma>) : i32 {
+      %1 = tt.load %arg0 : tensor<64x16x!tt.ptr<f16>, #blocked>
+      %2 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #linear>
+      %3 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+      %4 = tt.dot %3, %cst_1, %arg2 : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<64x64xf32, #mma1>
+      %5 = tt.trans %2 {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+      %6 = tt.dot %cst_2, %5, %arg3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x64xf32, #mma>
+      scf.yield %4, %6 : tensor<64x64xf32, #mma1>, tensor<128x64xf32, #mma>
+    }
+
+    %7 = ttg.convert_layout %0#0 : tensor<64x64xf32, #mma1> -> tensor<64x64xf32, #blocked>
+    %8 = ttg.convert_layout %0#1 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
+    tt.store %out0, %7 : tensor<64x64x!tt.ptr<f32>, #blocked>
+    tt.store %out1, %8 : tensor<128x64x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

test/TritonGPU/amd/amd-stream-loop-assume.mlir

Lines changed: 6 additions & 6 deletions
@@ -1,14 +1,14 @@
 // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2" -canonicalize | FileCheck %s

 // matmul: 128x32 @ 32x128 -> 128x128
-#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #ALs0 = #ttg.slice<{parent=#AL, dim=0}>
 #BLs0 = #ttg.slice<{parent=#BL, dim=0}>
 #BLs1 = #ttg.slice<{parent=#BL, dim=1}>
-#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
-#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
-#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
+#C = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 4}>
+#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 4}>

 // CHECK-LABEL: tt.func @assume_matmul
 // CHECK-COUNT-2: tt.load
@@ -27,7 +27,7 @@
 // CHECK: tt.dot
 // CHECK-NOT: tt.dot

-module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   tt.func @assume_matmul(%lb : index, %ub : index, %step : index,
                          %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                          %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {

test/TritonGPU/amd/amd-stream-prefetch.mlir

Lines changed: 6 additions & 6 deletions
@@ -4,14 +4,14 @@
 // RUN: triton-opt %s -tritonamdgpu-stream-pipeline="num_stages=2 local_prefetch=1" -canonicalize | FileCheck %s --check-prefixes=LOCAL_1

 // matmul: 128x32 @ 32x128 -> 128x128
-#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #ALs0 = #ttg.slice<{parent=#AL, dim=0}>
 #BLs0 = #ttg.slice<{parent=#BL, dim=0}>
 #BLs1 = #ttg.slice<{parent=#BL, dim=1}>
-#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
-#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
-#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
+#C = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 4}>
+#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 4}>

 // An extra register buffer for global loads.
 // GLOBAL_1-LABEL: tt.func @matmul_loop
@@ -74,7 +74,7 @@
 // LOCAL_1: tt.dot
 // LOCAL_1-NOT: tt.dot

-module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
                        %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                        %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 22 deletions
@@ -707,28 +707,6 @@ bool isChainDotHead(tt::DotOpInterface dotOp, unsigned opIdx) {
   return false;
 }

-bool hasTransInDefChain(tt::DotOpInterface dotOp, unsigned opIdx) {
-  auto isInSameRegion = [&dotOp](Operation *op) {
-    return op->getParentRegion() == dotOp->getParentRegion();
-  };
-
-  BackwardSliceOptions bwdOpt;
-  bwdOpt.omitBlockArguments = true;
-  bwdOpt.filter = isInSameRegion;
-  SetVector<Operation *> bwdSlices;
-  Operation *dotOperand = (opIdx == 0) ? dotOp.getA().getDefiningOp()
-                                       : dotOp.getB().getDefiningOp();
-
-  if (!dotOperand)
-    return false;
-  (void)getBackwardSlice(dotOperand, &bwdSlices, bwdOpt);
-  if (llvm::find_if(bwdSlices, [](Operation *op) {
-        return isa<tt::TransOp>(op);
-      }) != bwdSlices.end())
-    return true;
-  return false;
-}
-
 bool isChainDotTail(tt::DotOpInterface dotOp) {
   auto isInSameRegion = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 0 additions & 4 deletions
@@ -115,10 +115,6 @@ bool isUsedByDotScaledOp(Operation *op);
 // in the same region
 bool isChainDotHead(mlir::triton::DotOpInterface dotOp, unsigned opIdx = 0);

-// Check if given operand of this tt.dot is the result of a tt.trans
-// in the same region
-bool hasTransInDefChain(mlir::triton::DotOpInterface dotOp, unsigned opIdx);
-
 // Check if the opA of this tl.dot is the result of another tl.dot
 // in the same region
 bool isChainDotTail(mlir::triton::DotOpInterface dotOp);

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 0 additions & 13 deletions
@@ -15,7 +15,6 @@

 namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
-using ::mlir::LLVM::AMD::hasTransInDefChain;
 using ::mlir::LLVM::AMD::isChainDotHead;
 using ::mlir::LLVM::AMD::isChainDotTail;
 using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType;
@@ -563,18 +562,6 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     if (is16BitElemTy && isDotChainTail) {
       kWidth = 4;
     }
-    // For FA bwd kernel (detected using hasTransInDefChain), depending on
-    // whether the dot is a head or tail in the chain, we adjust the kWidth
-    // accordingly. This will enable us to create the same shared encoding per
-    // pair of tt.dot ops that both use the same tt.load result, one directly
-    // and one via tt.trans, later in the pass pipeline.
-    if (is16BitElemTy && hasTransInDefChain(dotOp, 1u)) {
-      if (isChainDotHead(dotOp)) {
-        kWidth = 4;
-      } else if (isDotChainTail) {
-        kWidth = 8;
-      }
-    }

     Value newDot;
     if (withScale) {
