
Commit 081ae01

Merge commit 'd629bda9b50a0bbabf31a9c145d986b43ad49965'
2 parents: 942c79f + d629bda

38 files changed: 998 additions, 283 deletions


.github/workflows/documentation.yml

Lines changed: 3 additions & 2 deletions
@@ -4,7 +4,8 @@ on:
   schedule:
     - cron: "0 0 * * *"
 
-permissions: read-all
+permissions:
+  contents: write
 
 jobs:
   Build-Documentation:
@@ -15,7 +16,7 @@ jobs:
       - name: Checkout branch
         uses: actions/checkout@v4
         with:
-          token: ${{ secrets.CI_PAT }}
+          token: ${{ secrets.GITHUB_TOKEN }}
           fetch-depth: 0
 
       - name: Clear docs

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeTMemLayouts.cpp

Lines changed: 89 additions & 45 deletions
@@ -22,75 +22,119 @@ namespace {
 
 // clang-format off
 // Converts:
-// %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked>
-// %r = tt.reshape %l : tensor<128x256xf32, #blocked> -> tensor<128x2x128xf32, #blocked4>
-// %t = tt.trans %r {order = array<i32: 0, 2, 1>} : tensor<128x2x128xf32, #blocked4> -> tensor<128x128x2xf32, #blocked5>
-// %outLHS, %outRHS = tt.split %t : tensor<128x128x2xf32, #blocked5> -> tensor<128x128xf32, #blocked2>
-// To:
-// %o0 = ttng.tmem_subslice %o { N = 0 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-// %outLHS = ttng.tmem_load %o0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
-// %o1 = ttng.tmem_subslice %o { N = 128 }: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-// %outRHS = ttng.tmem_load %o1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+// %l = ttng.tmem_load %o : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+//        -> tensor<128x256xf32, #blocked>
+// %r = tt.reshape %l : tensor<128x256xf32, #blocked>
+//        -> tensor<128x2x128xf32, #blocked4>
+// %t = tt.trans %r {order = array<i32: 0, 2, 1>}
+//        -> tensor<128x128x2xf32, #blocked5>
+// %lhs, %rhs = tt.split %t
+//
+// becomes
+// %o0 = ttng.tmem_subslice %o { N = 0 }
+// %lhs = ttng.tmem_load %o0
+// %o1 = ttng.tmem_subslice %o { N = 128 }
+// %rhs = ttng.tmem_load %o1
+//
+// and if %lhs / %rhs are split again through the same reshape->trans->split
+// pattern, the transformation can match again so that each further
+// split is materialised as an independent `ttng.tmem_subslice` / `ttng.tmem_load`
+// pair. Consequently, a chain such as
+//
+//   acc0, acc1   = split(permute(reshape(acc , ...)))
+//   acc00, acc01 = split(permute(reshape(acc0, ...)))
+//   acc10, acc11 = split(permute(reshape(acc1, ...)))
+//
+// is lowered to four independent TMEM loads operating on four disjoint
+// subslices.
+//
 // clang-format on
-// This will change the layout of the destination tensor to distribute each
-// slice across warps. It currently only supports simple cases where tmem can be
-// sliced easily. This could be extended if needed with more powerful slicing
-// support of tmem.
+// Strip away all intermediate ttg.convert_layout ops to reach the true
+// producer.
+static Value stripConvertLayout(Value v) {
+  while (auto cvt = v.getDefiningOp<ttg::ConvertLayoutOp>())
+    v = cvt.getSrc();
+  return v;
+}
+
 class TMemSplitLoadPattern : public OpRewritePattern<SplitOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SplitOp splitOp,
                                 PatternRewriter &rewriter) const override {
-    auto src = splitOp.getSrc();
-    // Skip convert layout ops.
-    while (auto cvt = src.getDefiningOp<ttg::ConvertLayoutOp>()) {
-      src = cvt.getSrc();
-    }
-    // Only support splitting N dimension on the outer most.
+    // -----------------------------------------------------------------------
+    // Match the pattern:
+    //   splitOp
+    //    ^ |
+    //    | +-- transOp(order = [0, 2, 1])
+    //    |      ^ |
+    //    |      | +-- reshapeOp
+    //    |      |      ^ |
+    //    |      |      | +-- (maybe convert_layout)
+    //    |      |      +-- tmemLoad
+    // -----------------------------------------------------------------------
+
+    // Starting from the split source, peel off convert_layouts if any.
+    Value src = stripConvertLayout(splitOp.getSrc());
     auto transOp = src.getDefiningOp<TransOp>();
     if (!transOp || transOp.getOrder() != ArrayRef<int>({0, 2, 1}))
      return failure();
     auto reshapeOp = transOp.getSrc().getDefiningOp<ReshapeOp>();
     if (!reshapeOp)
      return failure();
-    auto shape = reshapeOp.getResult().getType().getShape();
-    if (shape[0] != reshapeOp.getSrc().getType().getShape()[0])
-      return failure();
-    auto tmemLoad = reshapeOp.getSrc().getDefiningOp<TMEMLoadOp>();
+
+    // Peel off convert_layouts *below* the reshape as well. This is required
+    // for the recursive case where the producer of the reshape is the result
+    // of an earlier optimisation pass (i.e. a convert_layout of a previous
+    // tmem_load).
+    Value reshapeSrc = stripConvertLayout(reshapeOp.getSrc());
+    auto tmemLoad = reshapeSrc.getDefiningOp<TMEMLoadOp>();
     if (!tmemLoad)
      return failure();
-    // We found a tmem_load that is split on the N dimension. We can split it
-    // into multiple tmem_loads.
+
+    auto shape = reshapeOp.getResult().getType().getShape();
+    // Ensure M dimension is preserved by the reshape.
+    if (shape[0] != cast<RankedTensorType>(reshapeSrc.getType()).getShape()[0])
+      return failure();
     int mDim = getShapePerCTA(tmemLoad.getSrc().getType())[0];
     // TODO: enable other M cases. (the layout is a bit more complex).
     if (mDim != 128)
      return failure();
     int splitNSize = shape[2];
     if (splitNSize < 8)
      return failure();
-    Value tmem = tmemLoad.getSrc();
+
+    // Create the two TMEM subslices and their corresponding loads.
+    Value tmem = tmemLoad.getSrc(); // Could itself be a subslice.
     int numWarps = ttg::lookupNumWarps(tmemLoad);
     rewriter.setInsertionPoint(tmemLoad);
-    // First slice.
-    Value subSlice0 =
-        rewriter.create<TMEMSubSliceOp>(tmemLoad.getLoc(), tmem, 0, splitNSize);
-    Attribute distLayout = getTmemCompatibleLayout(
-        mDim, splitNSize, splitOp.getOutLHS().getType(), numWarps);
-    RankedTensorType newLoadType = RankedTensorType::get(
-        splitOp.getOutLHS().getType().getShape(),
-        splitOp.getOutLHS().getType().getElementType(), distLayout);
-    auto load0 =
-        rewriter.create<TMEMLoadOp>(tmemLoad.getLoc(), newLoadType, subSlice0);
-    auto cvt0 = rewriter.create<ttg::ConvertLayoutOp>(
-        tmemLoad.getLoc(), splitOp.getOutLHS().getType(), load0);
-    // Second slice.
-    Value subSlice1 = rewriter.create<TMEMSubSliceOp>(tmemLoad.getLoc(), tmem,
-                                                      splitNSize, splitNSize);
-    auto load1 =
-        rewriter.create<TMEMLoadOp>(tmemLoad.getLoc(), newLoadType, subSlice1);
-    auto cvt1 = rewriter.create<ttg::ConvertLayoutOp>(
-        tmemLoad.getLoc(), splitOp.getOutRHS().getType(), load1);
+
+    auto createSliceLoad =
+        [&](int64_t nOffset) -> std::pair<TMEMLoadOp, ttg::ConvertLayoutOp> {
+      // Generate the subslice op.
+      Value subSlice = rewriter.create<TMEMSubSliceOp>(tmemLoad.getLoc(), tmem,
+                                                       nOffset, splitNSize);
+
+      // Choose a layout compatible with the slice size.
+      Attribute distLayout = getTmemCompatibleLayout(
+          mDim, splitNSize, splitOp.getOutLHS().getType(), numWarps);
+
+      RankedTensorType newLoadType = RankedTensorType::get(
+          splitOp.getOutLHS().getType().getShape(),
+          splitOp.getOutLHS().getType().getElementType(), distLayout);
+
+      // Generate the load and convert_layout back to the original layout.
+      auto load =
+          rewriter.create<TMEMLoadOp>(tmemLoad.getLoc(), newLoadType, subSlice);
+      auto cvt = rewriter.create<ttg::ConvertLayoutOp>(
+          tmemLoad.getLoc(), splitOp.getOutLHS().getType(), load);
+
+      return {load, cvt};
+    };
+
+    auto [load0, cvt0] = createSliceLoad(/*nOffset=*/0);
+    auto [load1, cvt1] = createSliceLoad(/*nOffset=*/splitNSize);
     rewriter.replaceOp(splitOp, {cvt0, cvt1});
     return success();
   }
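
For readers unfamiliar with where such IR comes from, here is a minimal Triton-level sketch of a kernel that produces the reshape -> permute -> split chain this pattern targets. The kernel name, pointer arguments, strides, and tile sizes are illustrative assumptions, not taken from this commit; tl.reshape / tl.permute / tl.split are the front-end ops that lower to the tt.reshape / tt.trans / tt.split sequence matched above.

import triton
import triton.language as tl

@triton.jit
def split_acc_kernel(a_ptr, b_ptr, lhs_ptr, rhs_ptr):
    # Build a 128x256 fp32 accumulator, then split it into two 128x128 halves
    # along N. The reshape -> permute -> split sequence is what the pattern
    # above rewrites into two ttng.tmem_subslice + ttng.tmem_load pairs
    # instead of a single full-width tmem_load.
    offs_m = tl.arange(0, 128)
    offs_k = tl.arange(0, 64)
    offs_n = tl.arange(0, 256)
    a = tl.load(a_ptr + offs_m[:, None] * 64 + offs_k[None, :])   # 128x64 tile
    b = tl.load(b_ptr + offs_k[:, None] * 256 + offs_n[None, :])  # 64x256 tile
    acc = tl.dot(a, b)                    # tensor<128x256xf32>
    acc = tl.reshape(acc, (128, 2, 128))  # tensor<128x2x128xf32>
    acc = tl.permute(acc, (0, 2, 1))      # tensor<128x128x2xf32>
    lhs, rhs = tl.split(acc)              # two tensor<128x128xf32>
    out_n = tl.arange(0, 128)
    tl.store(lhs_ptr + offs_m[:, None] * 128 + out_n[None, :], lhs)
    tl.store(rhs_ptr + offs_m[:, None] * 128 + out_n[None, :], rhs)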

python/src/ir.cc

Lines changed: 4 additions & 0 deletions
@@ -756,6 +756,10 @@ void init_triton_ir(py::module &&m) {
            [](TritonOpBuilder &self, int32_t value) {
              return self.getBuilder().getI32IntegerAttr(value);
            })
+      .def("get_string_attr",
+           [](TritonOpBuilder &self, std::string value) -> Attribute {
+             return self.getBuilder().getStringAttr(value);
+           })
       // Use arith.ConstantOp to create constants
       // Constants
       .def("get_int1",
