@@ -22,75 +22,119 @@ namespace {
22
22
23
23
// clang-format off
24
24
// Converts:
25
- // %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked>
26
- // %r = tt.reshape %l : tensor<128x256xf32, #blocked> -> tensor<128x2x128xf32, #blocked4>
27
- // %t = tt.trans %r {order = array<i32: 0, 2, 1>} : tensor<128x2x128xf32, #blocked4> -> tensor<128x128x2xf32, #blocked5>
28
- // %outLHS, %outRHS = tt.split %t : tensor<128x128x2xf32, #blocked5> -> tensor<128x128xf32, #blocked2>
29
- // To:
30
- // %o0 = ttng.tmem_subslice %o { N = 0 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
31
- // %outLHS = ttng.tmem_load %o0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
32
- // %o1 = ttng.tmem_subslice %o { N = 128 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
33
- // %outRHS = ttng.tmem_load %o1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
25
+ // %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
26
+ // -> tensor<128x256xf32, #blocked>
27
+ // %r = tt.reshape %l : tensor<128x256xf32, #blocked>
28
+ // -> tensor<128x2x128xf32, #blocked4>
29
+ // %t = tt.trans %r {order = array<i32: 0, 2, 1>}
30
+ // -> tensor<128x128x2xf32, #blocked5>
31
+ // %lhs, %rhs = tt.split %t
32
+ //
33
+ // becomes
34
+ // %o0 = ttng.tmem_subslice %o { N = 0 }
35
+ // %lhs = ttng.tmem_load %o0
36
+ // %o1 = ttng.tmem_subslice %o { N = 128 }
37
+ // %rhs = ttng.tmem_load %o1
38
+ //
39
+ // and if %lhs / %rhs are split again through the same reshape->trans->split
40
+ // pattern, the transformation can match again so that each further
41
+ // split is materialised as an independent `ttng.tmem_subslice` / `ttng.tmem_load`
42
+ // pair. Consequently, a chain such as
43
+ //
44
+ // acc0, acc1 = split(permute(reshape(acc , ...)))
45
+ // acc00, acc01 = split(permute(reshape(acc0, ...)))
46
+ // acc10, acc11 = split(permute(reshape(acc1, ...)))
47
+ //
48
+ // is lowered to four independent TMEM loads operating on four disjoint
49
+ // subslices.
50
+ //
34
51
// clang-format on
35
- // This will change the layout of the destination tensor to distribute each
36
- // slice across warps. It currently only supports simple cases where tmem can be
37
- // sliced easily. This could be extended if needed with more powerful slicing
38
- // support of tmem.
52
+ // Strip away all intermediate ttg.convert_layout ops to reach the true
53
+ // producer.
54
+ static Value stripConvertLayout (Value v) {
55
+ while (auto cvt = v.getDefiningOp <ttg::ConvertLayoutOp>())
56
+ v = cvt.getSrc ();
57
+ return v;
58
+ }
59
+
39
60
class TMemSplitLoadPattern : public OpRewritePattern <SplitOp> {
40
61
public:
41
62
using OpRewritePattern::OpRewritePattern;
42
63
43
64
// Rewrites a tt.split that is fed (through optional convert_layout ops) by
//   ttng.tmem_load -> tt.reshape -> tt.trans(order = [0, 2, 1])
// into two ttng.tmem_subslice + ttng.tmem_load pairs, so each half of the
// accumulator is read directly from its own disjoint TMEM subslice. Because
// each new load is itself a ttng.tmem_load, the pattern can fire again on a
// further reshape->trans->split of either half, recursively subdividing the
// tensor memory.
//
// Returns success() after replacing `splitOp`, failure() if the producer
// chain does not match or the slice shape is unsupported.
LogicalResult matchAndRewrite(SplitOp splitOp,
                              PatternRewriter &rewriter) const override {
  // Peel off convert_layout ops between the split and the transpose.
  Value src = stripConvertLayout(splitOp.getSrc());
  auto transOp = src.getDefiningOp<TransOp>();
  if (!transOp || transOp.getOrder() != ArrayRef<int>({0, 2, 1}))
    return failure();
  auto reshapeOp = transOp.getSrc().getDefiningOp<ReshapeOp>();
  if (!reshapeOp)
    return failure();

  // Peel off convert_layouts *below* the reshape as well. This is required
  // for the recursive case where the reshape consumes the convert_layout of
  // a tmem_load created by a previous application of this pattern.
  Value reshapeSrc = stripConvertLayout(reshapeOp.getSrc());
  auto tmemLoad = reshapeSrc.getDefiningOp<TMEMLoadOp>();
  if (!tmemLoad)
    return failure();

  auto shape = reshapeOp.getResult().getType().getShape();
  // Only support splitting the N dimension: the reshape must keep M intact.
  if (shape[0] != cast<RankedTensorType>(reshapeSrc.getType()).getShape()[0])
    return failure();
  int mDim = getShapePerCTA(tmemLoad.getSrc().getType())[0];
  // TODO: enable other M cases. (the layout is a bit more complex).
  if (mDim != 128)
    return failure();
  int splitNSize = shape[2];
  if (splitNSize < 8)
    return failure();

  Value tmem = tmemLoad.getSrc(); // Could itself be a subslice.
  int numWarps = ttg::lookupNumWarps(tmemLoad);
  rewriter.setInsertionPoint(tmemLoad);

  // The distributed layout and the load type are identical for both halves
  // (LHS and RHS of a split have the same shape/element type), so compute
  // them once instead of per slice.
  Attribute distLayout = getTmemCompatibleLayout(
      mDim, splitNSize, splitOp.getOutLHS().getType(), numWarps);
  RankedTensorType newLoadType = RankedTensorType::get(
      splitOp.getOutLHS().getType().getShape(),
      splitOp.getOutLHS().getType().getElementType(), distLayout);

  // Emits subslice + load at `nOffset`, then converts the loaded value back
  // to `resultType`. The result type is a parameter so the RHS half is
  // converted to the RHS output type (not unconditionally the LHS type).
  auto createSliceLoad = [&](int64_t nOffset, RankedTensorType resultType) {
    Value subSlice = rewriter.create<TMEMSubSliceOp>(tmemLoad.getLoc(), tmem,
                                                     nOffset, splitNSize);
    auto load =
        rewriter.create<TMEMLoadOp>(tmemLoad.getLoc(), newLoadType, subSlice);
    return rewriter.create<ttg::ConvertLayoutOp>(tmemLoad.getLoc(),
                                                 resultType, load);
  };

  auto cvt0 = createSliceLoad(/*nOffset=*/0, splitOp.getOutLHS().getType());
  auto cvt1 =
      createSliceLoad(/*nOffset=*/splitNSize, splitOp.getOutRHS().getType());
  rewriter.replaceOp(splitOp, {cvt0, cvt1});
  return success();
}
0 commit comments