
Commit d141ab8

yongjikaeng-openai and Austin Eng authored
[BACKEND] Allow tmem split pass to match splits of tmem subslices (#7044)
This allows the TMEM load to be split into 4 (or even larger power-of-2) separate loads. Splitting into more, smaller loads can decrease register pressure in matmul loop epilogues. * Cherry-picked from PR #6566 and applied on top of latest main. --------- Co-authored-by: Austin Eng <[email protected]>
1 parent ffc614d commit d141ab8
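
The commit message says one 128x256 TMEM load can now be broken into four loads via nested splits. The standalone C++ sketch below is not part of the commit; the variable names and the power-of-two split loop are illustrative assumptions. It only computes the absolute N offsets that a two-level split chain produces for a 128x256 accumulator, matching the four 128x64 subslices checked in the new test (note the test's tmem_subslice offsets are relative to their parent slice: 0/128 at the outer level, then 0/64 within each half).

// Illustrative only: compute the absolute TMEM column (N) offsets produced by
// a power-of-two split chain over a 128x256 accumulator tile.
#include <cstdio>
#include <vector>

int main() {
  int totalN = 256; // N extent of the accumulator tile
  int levels = 2;   // two nested reshape->trans->split rounds => 4 slices
  std::vector<int> offsets = {0};
  int sliceN = totalN;
  for (int l = 0; l < levels; ++l) {
    sliceN /= 2; // each split halves the N extent of every slice
    std::vector<int> next;
    for (int off : offsets) {
      next.push_back(off);          // left half  (subslice at N = off)
      next.push_back(off + sliceN); // right half (subslice at N = off + sliceN)
    }
    offsets = next;
  }
  // Prints offsets 0, 64, 128, 192 with width 64.
  for (int off : offsets)
    std::printf("tmem_subslice N = %d, width = %d\n", off, sliceN);
  return 0;
}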

File tree

- lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp
- test/TritonNvidiaGPU/tmem_layouts.mlir

2 files changed: +135, -45 lines changed


lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp

Lines changed: 89 additions & 45 deletions
@@ -22,75 +22,119 @@ namespace {
 
 // clang-format off
 // Converts:
-//  %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked>
-//  %r = tt.reshape %l : tensor<128x256xf32, #blocked> -> tensor<128x2x128xf32, #blocked4>
-//  %t = tt.trans %r {order = array<i32: 0, 2, 1>} : tensor<128x2x128xf32, #blocked4> -> tensor<128x128x2xf32, #blocked5>
-//  %outLHS, %outRHS = tt.split %t : tensor<128x128x2xf32, #blocked5> -> tensor<128x128xf32, #blocked2>
-// To:
-//  %o0 = ttng.tmem_subslice %o { N = 0 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-//  %outLHS = ttng.tmem_load %o0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
-//  %o1 = ttng.tmem_subslice %o { N = 128 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-//  %outRHS = ttng.tmem_load %o1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+//  %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+//       -> tensor<128x256xf32, #blocked>
+//  %r = tt.reshape %l : tensor<128x256xf32, #blocked>
+//       -> tensor<128x2x128xf32, #blocked4>
+//  %t = tt.trans %r {order = array<i32: 0, 2, 1>}
+//       -> tensor<128x128x2xf32, #blocked5>
+//  %lhs, %rhs = tt.split %t
+//
+// becomes
+//  %o0 = ttng.tmem_subslice %o { N = 0 }
+//  %lhs = ttng.tmem_load %o0
+//  %o1 = ttng.tmem_subslice %o { N = 128 }
+//  %rhs = ttng.tmem_load %o1
+//
+// and if %lhs / %rhs are split again through the same reshape->trans->split
+// pattern, the transformation is can match again so that each further
+// split is materialised as an independent `ttng.tmem_subslice` / `ttng.tmem_load`
+// pair. Consequently, a chain such as
+//
+//   acc0, acc1   = split(permute(reshape(acc , ...)))
+//   acc00, acc01 = split(permute(reshape(acc0, ...)))
+//   acc10, acc11 = split(permute(reshape(acc1, ...)))
+//
+// is lowered to four independent TMEM loads operating on four disjoint
+// subslices.
+//
 // clang-format on
-// This will change the layout of the destination tensor to distribute each
-// slice across warps. It currently only supports simple cases where tmem can be
-// sliced easily. This could be extended if needed with more powerful slicing
-// support of tmem.
+// Strip away all intermediate ttg.convert_layout ops to reach the true
+// producer.
+static Value stripConvertLayout(Value v) {
+  while (auto cvt = v.getDefiningOp<ttg::ConvertLayoutOp>())
+    v = cvt.getSrc();
+  return v;
+}
+
 class TMemSplitLoadPattern : public OpRewritePattern<SplitOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SplitOp splitOp,
                                 PatternRewriter &rewriter) const override {
-    auto src = splitOp.getSrc();
-    // Skip convert layout ops.
-    while (auto cvt = src.getDefiningOp<ttg::ConvertLayoutOp>()) {
-      src = cvt.getSrc();
-    }
-    // Only support splitting N dimension on the outer most.
+    // -----------------------------------------------------------------------
+    // Match the pattern:
+    //   splitOp
+    //     ^   |
+    //     |   +-- transOp(order = [0, 2, 1])
+    //     |         ^   |
+    //     |         |   +-- reshapeOp
+    //     |         |         ^   |
+    //     |         |         |   +-- (maybe convert_layout)
+    //     |         |         +-- tmemLoad
+    // -----------------------------------------------------------------------
+
+    // Starting from the split source, peel off convert_layouts if any.
+    Value src = stripConvertLayout(splitOp.getSrc());
     auto transOp = src.getDefiningOp<TransOp>();
     if (!transOp || transOp.getOrder() != ArrayRef<int>({0, 2, 1}))
       return failure();
     auto reshapeOp = transOp.getSrc().getDefiningOp<ReshapeOp>();
     if (!reshapeOp)
       return failure();
-    auto shape = reshapeOp.getResult().getType().getShape();
-    if (shape[0] != reshapeOp.getSrc().getType().getShape()[0])
-      return failure();
-    auto tmemLoad = reshapeOp.getSrc().getDefiningOp<TMEMLoadOp>();
+
+    // Peel off convert_layouts *below* the reshape as well. This is required
+    // for the recursive case where the producer of the reshape is the result
+    // of an earlier optimisation pass (i.e. a convert_layout of a previous
+    // tmem_load).
+    Value reshapeSrc = stripConvertLayout(reshapeOp.getSrc());
+    auto tmemLoad = reshapeSrc.getDefiningOp<TMEMLoadOp>();
     if (!tmemLoad)
       return failure();
-    // We found a tmem_load that is split on the N dimension. We can split it
-    // into multiple tmem_loads.
+
+    auto shape = reshapeOp.getResult().getType().getShape();
+    // Ensure M dimension is preserved by the reshape.
+    if (shape[0] != cast<RankedTensorType>(reshapeSrc.getType()).getShape()[0])
+      return failure();
     int mDim = getShapePerCTA(tmemLoad.getSrc().getType())[0];
     // TODO: enable other M cases. (the layout is a bit more complex).
     if (mDim != 128)
       return failure();
     int splitNSize = shape[2];
     if (splitNSize < 8)
       return failure();
-    Value tmem = tmemLoad.getSrc();
+
+    // Create the two TMEM subslices and their corresponding loads.
+    Value tmem = tmemLoad.getSrc(); // Could itself be a subslice.
     int numWarps = ttg::lookupNumWarps(tmemLoad);
     rewriter.setInsertionPoint(tmemLoad);
-    // First slice.
-    Value subSlice0 =
-        rewriter.create<TMEMSubSliceOp>(tmemLoad.getLoc(), tmem, 0, splitNSize);
-    Attribute distLayout = getTmemCompatibleLayout(
-        mDim, splitNSize, splitOp.getOutLHS().getType(), numWarps);
-    RankedTensorType newLoadType = RankedTensorType::get(
-        splitOp.getOutLHS().getType().getShape(),
-        splitOp.getOutLHS().getType().getElementType(), distLayout);
-    auto load0 =
-        rewriter.create<TMEMLoadOp>(tmemLoad.getLoc(), newLoadType, subSlice0);
-    auto cvt0 = rewriter.create<ttg::ConvertLayoutOp>(
-        tmemLoad.getLoc(), splitOp.getOutLHS().getType(), load0);
-    // Second slice.
-    Value subSlice1 = rewriter.create<TMEMSubSliceOp>(tmemLoad.getLoc(), tmem,
-                                                      splitNSize, splitNSize);
-    auto load1 =
-        rewriter.create<TMEMLoadOp>(tmemLoad.getLoc(), newLoadType, subSlice1);
-    auto cvt1 = rewriter.create<ttg::ConvertLayoutOp>(
-        tmemLoad.getLoc(), splitOp.getOutRHS().getType(), load1);
+
+    auto createSliceLoad =
+        [&](int64_t nOffset) -> std::pair<TMEMLoadOp, ttg::ConvertLayoutOp> {
+      // Generate the subslice op.
+      Value subSlice = rewriter.create<TMEMSubSliceOp>(tmemLoad.getLoc(), tmem,
+                                                       nOffset, splitNSize);
+
+      // Choose a layout compatible with the slice size.
+      Attribute distLayout = getTmemCompatibleLayout(
+          mDim, splitNSize, splitOp.getOutLHS().getType(), numWarps);
+
+      RankedTensorType newLoadType = RankedTensorType::get(
+          splitOp.getOutLHS().getType().getShape(),
+          splitOp.getOutLHS().getType().getElementType(), distLayout);
+
+      // Generate the load and convert_layout back to the original layout.
+      auto load =
+          rewriter.create<TMEMLoadOp>(tmemLoad.getLoc(), newLoadType, subSlice);
+      auto cvt = rewriter.create<ttg::ConvertLayoutOp>(
+          tmemLoad.getLoc(), splitOp.getOutLHS().getType(), load);
+
+      return {load, cvt};
+    };
+
+    auto [load0, cvt0] = createSliceLoad(/*nOffset=*/0);
+    auto [load1, cvt1] = createSliceLoad(/*nOffset=*/splitNSize);
     rewriter.replaceOp(splitOp, {cvt0, cvt1});
     return success();
   }
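
As a side note on how such a pattern takes effect: the recursive behaviour described in the comments above (a split of an already-split result matching again) falls out of MLIR's greedy pattern rewriting, which revisits newly created ops until fixpoint. The sketch below is a hypothetical driver, not Triton's actual pass registration; the helper name runTMemSplitSketch and the surrounding boilerplate are assumptions, while RewritePatternSet and applyPatternsAndFoldGreedily are standard MLIR APIs.

// Minimal sketch of driving a rewrite pattern such as TMemSplitLoadPattern.
#include <utility>

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical entry point; Triton's real pass registration differs.
static mlir::LogicalResult runTMemSplitSketch(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  // patterns.add<TMemSplitLoadPattern>(ctx); // the pattern from the diff above
  // Greedy application until fixpoint: each round may expose another
  // reshape -> trans -> split chain rooted at a freshly created tmem_load.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}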

test/TritonNvidiaGPU/tmem_layouts.mlir

Lines changed: 46 additions & 0 deletions
@@ -28,6 +28,52 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 2, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [4, 1, 2], order = [1, 2, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 2, 1], order = [2, 1, 0]}>
+#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
+#blocked7 = #ttg.blocked<{sizePerThread = [1, 1, 128], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 2, 1], order = [0, 2, 1]}>
+#blocked8 = #ttg.blocked<{sizePerThread = [1, 128, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 2], order = [0, 1, 2]}>
+#linear = #ttg.linear<{register = [[0, 0, 1], [0, 64, 0], [4, 0, 0], [8, 0, 0], [16, 0, 0], [32, 0, 0], [64, 0, 0]], lane = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 16, 0]], warp = [
+[0, 32, 0], [1, 0, 0], [2, 0, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 64], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [1, 0], [2, 0]], block = []}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, unpacked = true>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100"} {
+  // CHECK-LABEL: @subtile4_tmem_load
+  tt.func public @subtile4_tmem_load(%arg0: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>) -> (tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>) {
+    // CHECK: %[[S0:.+]] = ttng.tmem_subslice %{{.+}} {N = 0 : i32}
+    // CHECK: %[[S1:.+]] = ttng.tmem_subslice %[[S0]] {N = 0 : i32}
+    // CHECK: %[[L1:.+]] = ttng.tmem_load %[[S1]] : !ttg.memdesc<128x64xf32
+    // CHECK: %[[C1:.+]] = ttg.convert_layout %[[L1]]
+    // CHECK: %[[S2:.+]] = ttng.tmem_subslice %[[S0]] {N = 64 : i32}
+    // CHECK: %[[L2:.+]] = ttng.tmem_load %[[S2]] : !ttg.memdesc<128x64xf32
+    // CHECK: %[[C2:.+]] = ttg.convert_layout %[[L2]]
+    // CHECK: %[[S3:.+]] = ttng.tmem_subslice %{{.+}} {N = 128 : i32}
+    // CHECK: %[[S4:.+]] = ttng.tmem_subslice %[[S3]] {N = 0 : i32}
+    // CHECK: %[[L4:.+]] = ttng.tmem_load %[[S4]] : !ttg.memdesc<128x64xf32
+    // CHECK: %[[C4:.+]] = ttg.convert_layout %[[L4]]
+    // CHECK: %[[S5:.+]] = ttng.tmem_subslice %[[S3]] {N = 64 : i32}
+    // CHECK: %[[L5:.+]] = ttng.tmem_load %[[S5]] : !ttg.memdesc<128x64xf32
+    // CHECK: %[[C5:.+]] = ttg.convert_layout %[[L5]]
+    // CHECK: tt.return %[[C1]], %[[C2]], %[[C4]], %[[C5]]
+    %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked>
+    %1 = tt.reshape %0 : tensor<128x256xf32, #blocked> -> tensor<128x2x128xf32, #blocked7>
+    %2 = tt.trans %1 {order = array<i32: 0, 2, 1>} : tensor<128x2x128xf32, #blocked7> -> tensor<128x128x2xf32, #blocked8>
+    %3 = ttg.convert_layout %2 : tensor<128x128x2xf32, #blocked8> -> tensor<128x128x2xf32, #linear>
+    %outLHS, %outRHS = tt.split %3 : tensor<128x128x2xf32, #linear> -> tensor<128x128xf32, #linear1>
+    %4 = tt.reshape %outLHS : tensor<128x128xf32, #linear1> -> tensor<128x2x64xf32, #blocked2>
+    %5 = tt.trans %4 {order = array<i32: 0, 2, 1>} : tensor<128x2x64xf32, #blocked2> -> tensor<128x64x2xf32, #blocked3>
+    %outLHS_1, %outRHS_1 = tt.split %5 : tensor<128x64x2xf32, #blocked3> -> tensor<128x64xf32, #blocked4>
+    %6 = tt.reshape %outRHS : tensor<128x128xf32, #linear1> -> tensor<128x2x64xf32, #blocked2>
+    %7 = tt.trans %6 {order = array<i32: 0, 2, 1>} : tensor<128x2x64xf32, #blocked2> -> tensor<128x64x2xf32, #blocked3>
+    %outLHS_2, %outRHS_2 = tt.split %7 : tensor<128x64x2xf32, #blocked3> -> tensor<128x64xf32, #blocked4>
+    tt.return %outLHS_1, %outRHS_1, %outLHS_2, %outRHS_2 : tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>, tensor<128x64xf32, #blocked4>
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked5 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked6 = #ttg.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
