Commit 93b2d49

[BACKEND] Enhance block store lowering (#4544)
The first step is to enable the memory analysis on the pointer value used by the tt.store operation. A later change will enhance the lowering of tt.store based on that memory information, mirroring what is already done for the lowering of tt.load.

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent ecb52e4 commit 93b2d49
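For context, the kind of kernel this change targets looks like the following Triton sketch (illustrative only, not part of this commit; the kernel name and parameters are hypothetical). Each tl.make_block_ptr lowers to tt.make_tensor_ptr, and the final tl.store lowers to a tt.store through a block pointer, which the pass can now annotate with ttig.block_io exactly as it already annotates tt.load:

import triton
import triton.language as tl

@triton.jit
def copy_tile_kernel(src_ptr, dst_ptr, M, N, stride_m, stride_n,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Lowers to tt.make_tensor_ptr; order=(1, 0) marks the innermost
    # dimension as fast-changing (row-major).
    src = tl.make_block_ptr(base=src_ptr, shape=(M, N),
                            strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=dst_ptr, shape=(M, N),
                            strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tile = tl.load(src, boundary_check=(0, 1))
    # Lowers to tt.store on a block pointer; with this change the pass can
    # tag it with ttig.block_io = "row_major" when strides and alignment allow.
    tl.store(dst, tile, boundary_check=(0, 1))

When stride_n is 1 and the base, pitch, and offsets satisfy the alignment checks exercised by the test below, both the load and the store receive the ttig.block_io attribute, which later lowering can use to emit 2D block IO instead of element-wise accesses.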

File tree

2 files changed: +78 -34 lines changed

test/TritonIntelGPU/materialize-block-pointer.mlir

Lines changed: 33 additions & 0 deletions
@@ -18,13 +18,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", ttig.support_sg
 %4 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%pitch, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot_b>>
 %5 = tt.load %3 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
 %6 = tt.load %4 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>, ttig.block_io = "row_major"}
+tt.store %3, %5 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x32xf16, #dot_a>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>, ttig.block_io = "row_major"}
+tt.store %4, %6 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x64xf16, #dot_b>>

 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32}
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0>, padding = 1 : i32, ttig.block_io = "column_major"}
 %7 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %pitch], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf16, #dot_a>>
 %8 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %pitch], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x64xf16, #dot_b>>
 %9 = tt.load %7 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
 %10 = tt.load %8 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %7, %9 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x32xf16, #dot_a>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>, ttig.block_io = "column_major"}
+tt.store %8, %10 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x64xf16, #dot_b>>

 // COM: Non-constant stride on fast changing dim.
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32}
@@ -33,6 +41,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", ttig.support_sg
 %12 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%pitch, %pitch], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x64xf16, #dot_b>>
 %13 = tt.load %11 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
 %14 = tt.load %12 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %11, %13 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x32xf16, #dot_a>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %12, %14 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x64xf16, #dot_b>>

 // COM: Non-64 divisible pitch.
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32}
@@ -41,6 +53,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", ttig.support_sg
 %16 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %pitch_odd], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x64xf16, #dot_b>>
 %17 = tt.load %15 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
 %18 = tt.load %16 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %15, %17 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x32xf16, #dot_a>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %16, %18 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x64xf16, #dot_b>>

 // COM: Non 4 bytes aligned base.
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32}
@@ -49,6 +65,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", ttig.support_sg
 %20 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%pitch, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot_b>>
 %21 = tt.load %19 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
 %22 = tt.load %20 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %19, %21 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x32xf16, #dot_a>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %20, %22 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x64xf16, #dot_b>>

 // COM: Non 4 bytes aligned baseWidth.
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32}
@@ -57,6 +77,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", ttig.support_sg
 %24 = tt.make_tensor_ptr %arg0, [%c0_i64, %c15_i64], [%pitch, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot_b>>
 %25 = tt.load %23 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
 %26 = tt.load %24 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %23, %25 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x32xf16, #dot_a>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %24, %26 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x64xf16, #dot_b>>

 // COM: Non 4 bytes aligned offsetX.
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32}
@@ -65,6 +89,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", ttig.support_sg
 %28 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%pitch, %c1_i64], [%c0_i32, %c15_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot_b>>
 %29 = tt.load %27 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
 %30 = tt.load %28 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %27, %29 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<64x32xf16, #dot_a>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>}
+tt.store %28, %30 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x64xf16, #dot_b>>
+
 tt.return
 }
 }
@@ -103,6 +132,8 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
 // COM: 4 bytes aligned base (value got from addptr, addi, muli), baseWidth and offsetX (value got from muli).
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0>, padding = 1 : i32, ttig.block_io = "row_major"}
 %11 = tt.load %10 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>, ttig.block_io = "row_major"}
+tt.store %10, %11 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
 tt.return
 }
 }
@@ -130,6 +161,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 %14 = tt.advance %12, [%1, %arg3] : <tensor<8x128xf32, #blocked>>
 // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"}
 %15 = tt.load %14 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x128xf32, #blocked>>
+// CHECK: tt.store {{.*}} {boundaryCheck = array<i32: 1>, ttig.block_io = "row_major"}
+tt.store %14, %15 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<8x128xf32, #blocked>>
 scf.yield %12 : !tt.ptr<tensor<8x128xf32, #blocked>>
 }
 tt.return

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 45 additions & 34 deletions
@@ -7,6 +7,7 @@
 #include "mlir/IR/Visitors.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>

@@ -35,6 +36,18 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       TritonIntelGPUMaterializeBlockPointerPass>::
       TritonIntelGPUMaterializeBlockPointerBase;

+  static Value getPointerFromOp(Operation *op) {
+    return TypeSwitch<Operation *, Value>(op)
+        .Case<tt::LoadOp, tt::StoreOp>([](auto op) { return op.getPtr(); })
+        .Default([&](auto) {
+          llvm_unreachable(
+              +("Invalid operation: " + op->getName().getStringRef())
+                   .str()
+                   .c_str());
+          return Value{};
+        });
+  }
+
   void runOnOperation() override {
     ModuleOp mod = getOperation();
     if (!mod->hasAttr(
@@ -44,21 +57,21 @@ struct TritonIntelGPUMaterializeBlockPointerPass
     tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(mod);

     MLIRContext *context = &getContext();
-    mod.walk([&](tt::LoadOp loadOp) {
-      LDBG("Considering op: " << loadOp);
+    mod.walk([&](Operation *op) {
+      if (!isa<tt::LoadOp, tt::StoreOp>(op)) {
+        return;
+      }
+      LDBG("Considering op: " << *op);

-      Value ptr = loadOp.getPtr();
+      Value ptr = getPointerFromOp(op);
       if (!tt::isTensorPointerType(ptr.getType()))
-        return MaterializeTensorOfPointers(loadOp, axisInfoAnalysis);
-
-      assert(isa<RankedTensorType>(loadOp.getResult().getType()) &&
-             "Expected 'loadOp' to load a tensor value.");
+        return MaterializeTensorOfPointers(op, axisInfoAnalysis);

       // Find the make tensor ptr operation that created the base ptr.
       std::optional<tt::MakeTensorPtrOp> defOp =
           tt::intel::findDefiningMakeTensorPtrOp(ptr);
       if (!defOp) {
-        LDBG("Could not find make tensor ptr op for: " << loadOp);
+        LDBG("Could not find make tensor ptr op for: " << *op);
         return;
       }

@@ -71,8 +84,8 @@ struct TritonIntelGPUMaterializeBlockPointerPass
      if (rank == 1)
        return;

-      if (!satisfies2DBlockReadAlignment(loadOp, axisInfoAnalysis)) {
-        LDBG("Alignment checks failed for: " << loadOp);
+      if (!satisfies2DBlockReadAlignment(op, axisInfoAnalysis)) {
+        LDBG("Alignment checks failed for: " << *op);
        return;
      }

@@ -107,8 +120,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
        return;

      const bool isRowMajor = (strideOneDimVal == rank - 1);
-      std::optional<ttg::DotOperandEncodingAttr> dotLayout =
-          getDotLayout(loadOp);
+      std::optional<ttg::DotOperandEncodingAttr> dotLayout = getDotLayout(op);
      if (dotLayout) {
        // Check if the load is being used by a tt.dot operation, and if so is
        // this the first operand and is it a transposed row major matrix. If
@@ -127,31 +139,33 @@ struct TritonIntelGPUMaterializeBlockPointerPass
        }
      }

-      loadOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
-                      StringAttr::get(context, isRowMajor ? "row_major"
-                                                          : "column_major"));
+      op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
+                  StringAttr::get(context,
+                                  isRowMajor ? "row_major" : "column_major"));
      }
    });
  }

 private:
  void MaterializeTensorOfPointers(
-      tt::LoadOp loadOp,
+      Operation *op,
      tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
-    MLIRContext *context = loadOp.getContext();
-    Value ptr = loadOp.getPtr();
+    MLIRContext *context = op->getContext();
+    Value ptr = getPointerFromOp(op);
    assert(!tt::isTensorPointerType(ptr.getType()) &&
-           "Expected 'loadOp' to load a tensor value.");
+           "Expected pointer refer to a tensor.");

    auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
    if (!tensorTy)
      return;

-    LDBG("Considering tensor of pointer load op: " << loadOp);
+    LDBG("Considering tensor of pointer of memory accessing op: " << *op);

-    if (loadOp.getMask()) {
-      LDBG("Load op has mask, skip block IO attribute");
-      return;
+    if (auto loadOp = dyn_cast<tt::LoadOp>(*op)) {
+      if (loadOp.getMask()) {
+        LDBG("Load op has mask, skip block IO attribute");
+        return;
+      }
    }

    // The axis info gives the information about the value of the indices
@@ -215,8 +229,8 @@ struct TritonIntelGPUMaterializeBlockPointerPass
    // Check if loadOp is row major, i.e., fast changing dimension is one.
    if (isMajor(1 /*fastChangeDim*/)) {
      LDBG("Setting row_major attribute\n");
-      loadOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
-                      StringAttr::get(context, "row_major"));
+      op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
+                  StringAttr::get(context, "row_major"));
    }

    // TODO: set column_major attribute
@@ -225,9 +239,8 @@ struct TritonIntelGPUMaterializeBlockPointerPass
  // Return the load layout if it is a dot layout. If it is not, check if the
  // load result is converted to a dot layout. If so, return the dot layout,
  // otherwise return nullopt.
-  std::optional<ttg::DotOperandEncodingAttr>
-  getDotLayout(tt::LoadOp loadOp) const {
-    Value ptr = loadOp.getPtr();
+  std::optional<ttg::DotOperandEncodingAttr> getDotLayout(Operation *op) const {
+    Value ptr = getPointerFromOp(op);
    if (!tt::isTensorPointerType(ptr.getType()))
      return std::nullopt;

@@ -254,7 +267,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
      });
    };

-    Operation::user_range users = loadOp->getUsers();
+    Operation::user_range users = op->getUsers();
    if (!users.empty() && allUsersAreConvertOps(users) &&
        allUserHaveIdenticalLayout(users)) {
      Attribute firstUserLayout =
@@ -282,13 +295,11 @@ struct TritonIntelGPUMaterializeBlockPointerPass
  }

  bool satisfies2DBlockReadAlignment(
-      tt::LoadOp loadOp,
+      Operation *op,
      tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
-    Value ptr = loadOp.getPtr();
+    Value ptr = getPointerFromOp(op);
    assert(tt::isTensorPointerType(ptr.getType()) &&
           "Expected a ptr to a tensor of ptrs.");
-    assert(isa<RankedTensorType>(loadOp.getResult().getType()) &&
-           "Expected 'loadOp' to load a ranked tensor value.");

    // Find the make tensor ptr operation that created the base ptr for the load
    // operation.
@@ -350,7 +361,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
    }
    LDBG("offset: " << offset);

-    Region *loadRgn = loadOp->getParentRegion();
+    Region *loadRgn = op->getParentRegion();
    Region *makeTensorPtrRgn = makeTensorPtrOp->getParentRegion();
    bool inSameRegion = (loadRgn == makeTensorPtrRgn);
    if (inSameRegion)
