Skip to content

Commit 79841bd

Browse files
authored
[NFI]: Improvements to the materialize block ptr pass (#5065)
This PR refactors the materialize block pointer pass to improve code organization and type safety. The changes replace generic operation handling with templated type-specific methods and remove helper functions that are no longer needed. --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 56c47fe commit 79841bd

File tree

1 file changed

+113
-123
lines changed

1 file changed

+113
-123
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 113 additions & 123 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,10 @@
33
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
44
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
55
#include "intel/include/Utils/Utility.h"
6+
#include "mlir/IR/BuiltinAttributes.h"
67
#include "mlir/IR/Value.h"
78
#include "mlir/IR/Visitors.h"
8-
#include "triton/Dialect/Triton/IR/Dialect.h"
99
#include "llvm/ADT/STLExtras.h"
10-
#include "llvm/ADT/TypeSwitch.h"
1110
#include "llvm/Support/Debug.h"
1211
#include <optional>
1312

@@ -36,137 +35,129 @@ struct TritonIntelGPUMaterializeBlockPointerPass
3635
TritonIntelGPUMaterializeBlockPointerPass>::
3736
TritonIntelGPUMaterializeBlockPointerBase;
3837

39-
static Value getPointerFromOp(Operation *op) {
40-
return TypeSwitch<Operation *, Value>(op)
41-
.Case<tt::LoadOp, tt::StoreOp>([](auto op) { return op.getPtr(); })
42-
.Default([&](auto) {
43-
llvm_unreachable(
44-
+("Invalid operation: " + op->getName().getStringRef())
45-
.str()
46-
.c_str());
47-
return Value{};
48-
});
49-
}
50-
5138
void runOnOperation() override {
5239
ModuleOp mod = getOperation();
5340
if (!mod->hasAttr(
5441
ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName()))
5542
return;
5643

5744
tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
58-
5945
MLIRContext *context = &getContext();
60-
mod.walk([&](Operation *op) {
61-
if (!isa<tt::LoadOp, tt::StoreOp>(op)) {
62-
return;
63-
}
64-
LDBG("Considering op: " << *op);
46+
mod.walk(
47+
[&](tt::LoadOp op) { return visit(op, axisInfoAnalysis, context); });
48+
mod.walk(
49+
[&](tt::StoreOp op) { return visit(op, axisInfoAnalysis, context); });
50+
}
6551

66-
Value ptr = getPointerFromOp(op);
67-
if (!tt::isTensorPointerType(ptr.getType()))
68-
return MaterializeTensorOfPointers(op, axisInfoAnalysis);
52+
private:
53+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
54+
OpType, tt::LoadOp, tt::StoreOp>::value>>
55+
void visit(OpType op, tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
56+
MLIRContext *context) const {
57+
LDBG("Considering op: " << *op);
6958

70-
// Find the make tensor ptr operation that created the base ptr.
71-
std::optional<tt::MakeTensorPtrOp> defOp =
72-
tt::intel::findDefiningMakeTensorPtrOp(ptr);
73-
if (!defOp) {
74-
LDBG("Could not find make tensor ptr op for: " << *op);
75-
return;
76-
}
59+
Value ptr = op.getPtr();
60+
if (!tt::isTensorPointerType(ptr.getType()))
61+
return MaterializeTensorOfPointers(op, axisInfoAnalysis);
7762

78-
tt::MakeTensorPtrOp makeTensorPtrOp = *defOp;
79-
LDBG("Make tensor ptr op: " << makeTensorPtrOp);
63+
// Find the make tensor ptr operation that created the base ptr.
64+
std::optional<tt::MakeTensorPtrOp> defOp =
65+
tt::intel::findDefiningMakeTensorPtrOp(ptr);
66+
if (!defOp) {
67+
LDBG("Could not find make tensor ptr op for: " << *op);
68+
return;
69+
}
8070

81-
Operation::operand_range shape = makeTensorPtrOp.getShape();
82-
unsigned rank = shape.size();
83-
LDBG("Rank: " << rank);
84-
if (rank == 1)
85-
return;
71+
tt::MakeTensorPtrOp makeTensorPtrOp = *defOp;
72+
LDBG("Make tensor ptr op: " << makeTensorPtrOp);
8673

87-
if (!satisfies2DBlockReadAlignment(op, axisInfoAnalysis)) {
88-
LDBG("Alignment checks failed for: " << *op);
89-
return;
90-
}
74+
Operation::operand_range shape = makeTensorPtrOp.getShape();
75+
unsigned rank = shape.size();
76+
LDBG("Rank: " << rank);
77+
if (rank == 1)
78+
return;
9179

92-
auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
93-
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
94-
unsigned elementWidth = tensorType.getElementTypeBitWidth();
95-
LDBG("elementWidth: " << elementWidth);
80+
if (!satisfies2DBlockReadAlignment(op, axisInfoAnalysis)) {
81+
LDBG("Alignment checks failed for: " << *op);
82+
return;
83+
}
9684

97-
Operation::operand_range strides = makeTensorPtrOp.getStrides();
98-
std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
99-
assert((strideOneDim && strideOneDim.value() < strides.size()) &&
100-
"Expected strideOneDim to be set and less than strides.size()");
101-
unsigned strideOneDimVal = strideOneDim.value();
85+
auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
86+
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
87+
unsigned elementWidth = tensorType.getElementTypeBitWidth();
88+
LDBG("elementWidth: " << elementWidth);
10289

103-
if (strideOneDimVal == rank - 2 && elementWidth == 8) {
104-
// TODO: column major layout w/ fp8 has performance regression
90+
Operation::operand_range strides = makeTensorPtrOp.getStrides();
91+
std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
92+
assert((strideOneDim && strideOneDim.value() < strides.size()) &&
93+
"Expected strideOneDim to be set and less than strides.size()");
94+
unsigned strideOneDimVal = strideOneDim.value();
95+
96+
if (strideOneDimVal == rank - 2 && elementWidth == 8) {
97+
// TODO: column major layout w/ fp8 has performance regression
98+
return;
99+
}
100+
101+
if (strideOneDimVal >= (rank - 2)) {
102+
// HW 2D block read instruction only supports contiguous access.
103+
Value fastChangeStride = strides[strideOneDimVal];
104+
if (!tt::intel::isConstant(fastChangeStride, 1))
105105
return;
106-
}
107106

108-
if (strideOneDimVal >= (rank - 2)) {
109-
// HW 2D block read instruction only supports contiguous access.
110-
Value fastChangeStride = strides[strideOneDimVal];
111-
if (!tt::intel::isConstant(fastChangeStride, 1))
112-
return;
107+
// Across Intel platforms, the strictest pitch restriction is to be a
108+
// multiple of OWord(128 bits).
109+
Value pitch =
110+
strides[(strideOneDimVal == rank - 1) ? rank - 2 : rank - 1];
111+
LDBG("Pitch: " << pitch);
112+
if (!ttgi::isDivisible(pitch, llvm::divideCeil(128, elementWidth)))
113+
return;
113114

114-
// Across Intel platforms, the strictest pitch restriction is to be a
115-
// multiple of OWord(128 bits).
116-
Value pitch =
117-
strides[(strideOneDimVal == rank - 1) ? rank - 2 : rank - 1];
118-
LDBG("Pitch: " << pitch);
119-
if (!ttgi::isDivisible(pitch, llvm::divideCeil(128, elementWidth)))
115+
const bool isRowMajor = (strideOneDimVal == rank - 1);
116+
std::optional<ttg::DotOperandEncodingAttr> dotLayout = getDotLayout(op);
117+
if (dotLayout) {
118+
// Check if the load is being used by a tt.dot operation, and if so is
119+
// this the first operand and is it a transposed row major matrix. If
120+
// so, skip the block ptr attribute as performance is worse than if we
121+
// remove the tensor pointer.
122+
LDBG("dotLayout: " << *dotLayout);
123+
auto opIdx =
124+
static_cast<ttgi::DpasEncodingAttr::OpIdx>(dotLayout->getOpIdx());
125+
auto dotOrder = tt::gpu::getThreadOrder(tensorType);
126+
const bool valueRowMajor = (dotOrder[0] == 1 && dotOrder[1] == 0);
127+
if (opIdx == ttgi::DpasEncodingAttr::OpIdx::OperandA &&
128+
valueRowMajor ^ isRowMajor) {
129+
LDBG("Skipping block pointer attribute for transposed A matrix in "
130+
"dot operation");
120131
return;
121-
122-
const bool isRowMajor = (strideOneDimVal == rank - 1);
123-
std::optional<ttg::DotOperandEncodingAttr> dotLayout = getDotLayout(op);
124-
if (dotLayout) {
125-
// Check if the load is being used by a tt.dot operation, and if so is
126-
// this the first operand and is it a transposed row major matrix. If
127-
// so, skip the block ptr attribute as performance is worse than if we
128-
// remove the tensor pointer.
129-
LDBG("dotLayout: " << *dotLayout);
130-
auto opIdx =
131-
static_cast<ttgi::DpasEncodingAttr::OpIdx>(dotLayout->getOpIdx());
132-
auto dotOrder = tt::gpu::getThreadOrder(tensorType);
133-
const bool valueRowMajor = (dotOrder[0] == 1 && dotOrder[1] == 0);
134-
if (opIdx == ttgi::DpasEncodingAttr::OpIdx::OperandA &&
135-
valueRowMajor ^ isRowMajor) {
136-
LDBG("Skipping block pointer attribute for transposed A matrix in "
137-
"dot operation");
138-
return;
139-
}
140132
}
141-
142-
op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
143-
StringAttr::get(context,
144-
isRowMajor ? "row_major" : "column_major"));
145133
}
146-
});
134+
135+
op->setAttr(
136+
ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
137+
StringAttr::get(context, isRowMajor ? "row_major" : "column_major"));
138+
}
147139
}
148140

149-
private:
141+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
142+
OpType, tt::LoadOp, tt::StoreOp>::value>>
150143
void MaterializeTensorOfPointers(
151-
Operation *op,
152-
tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
153-
MLIRContext *context = op->getContext();
154-
Value ptr = getPointerFromOp(op);
144+
OpType op, tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
145+
if constexpr (std::is_same_v<OpType, tt::LoadOp>) {
146+
if (op.getMask()) {
147+
LDBG("Load op has mask, skip block IO attribute");
148+
return;
149+
}
150+
}
151+
152+
Value ptr = op.getPtr();
155153
assert(!tt::isTensorPointerType(ptr.getType()) &&
156154
"Expected pointer refer to a tensor.");
157155

158156
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
159157
if (!tensorTy)
160158
return;
161159

162-
LDBG("Considering tensor of pointer of memory accessing op: " << *op);
163-
164-
if (auto loadOp = dyn_cast<tt::LoadOp>(*op)) {
165-
if (loadOp.getMask()) {
166-
LDBG("Load op has mask, skip block IO attribute");
167-
return;
168-
}
169-
}
160+
LDBG("Considering tensor of pointer of memory accessing op: " << op);
170161

171162
// The axis info gives the information about the value of the indices
172163
// tensor. For example, if the indices tensor is tensor<8x16xi32> and
@@ -187,60 +178,58 @@ struct TritonIntelGPUMaterializeBlockPointerPass
187178
}
188179

189180
// Determine if LoadOp is row-major or column-major.
190-
auto isMajor = [&](unsigned fastChangeDim) {
181+
auto isMajor = [](RankedTensorType tensorTy, unsigned fastChangeDim,
182+
const tt::AxisInfo &axisInfo) {
191183
assert((fastChangeDim == 0 || fastChangeDim == 1) &&
192184
"fastChangeDim is expected to be 0 or 1");
193185
const unsigned otherDim = !fastChangeDim;
194186
// Limit to full row being contiguous.
195-
if (axisInfo->getContiguity(fastChangeDim) !=
187+
if (axisInfo.getContiguity(fastChangeDim) !=
196188
tensorTy.getDimSize(fastChangeDim)) {
197189
LDBG("Found non-contiguous row: "
198-
<< axisInfo->getContiguity(fastChangeDim));
190+
<< axisInfo.getContiguity(fastChangeDim));
199191
return false;
200192
}
201193

202194
// Value -1 is used to represent the unknown stride.
203-
if (axisInfo->getStride(otherDim) < 0) {
204-
LDBG("Found unknown stride: " << axisInfo->getStride(otherDim));
195+
if (axisInfo.getStride(otherDim) < 0) {
196+
LDBG("Found unknown stride: " << axisInfo.getStride(otherDim));
205197
return false;
206198
}
207199

208200
// Surface pitch is required to be 16 bytes aligned.
209201
Type elemTy =
210202
cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
211203
unsigned elemSizeInBytes = elemTy.getIntOrFloatBitWidth() / 8;
212-
if ((axisInfo->getStride(otherDim) * elemSizeInBytes) % 16 != 0) {
204+
if ((axisInfo.getStride(otherDim) * elemSizeInBytes) % 16 != 0) {
213205
LDBG("Found Non 16 bytes aligned stride: "
214-
<< axisInfo->getStride(otherDim));
206+
<< axisInfo.getStride(otherDim));
215207
return false;
216208
}
217209

218210
// Base pointer can be compensate by the offset and base width, where they
219211
// each has restriction that it has to be 4 bytes aligned.
220-
if (axisInfo->getDivisibility(fastChangeDim) % 4 != 0) {
221-
LDBG(
222-
"Found Non 4 bytes aligned base: " << axisInfo->getDivisibility(1));
212+
if (axisInfo.getDivisibility(fastChangeDim) % 4 != 0) {
213+
LDBG("Found Non 4 bytes aligned base: " << axisInfo.getDivisibility(1));
223214
return false;
224215
}
225216

226217
return true;
227218
};
228219

229-
// Check if loadOp is row major, i.e., fast changing dimension is one.
230-
if (isMajor(1 /*fastChangeDim*/)) {
231-
LDBG("Setting row_major attribute\n");
220+
const bool isRowMajor = isMajor(tensorTy, 1 /*fastChangeDim*/, *axisInfo);
221+
if (isRowMajor)
232222
op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
233-
StringAttr::get(context, "row_major"));
234-
}
235-
236-
// TODO: set column_major attribute
223+
StringAttr::get(op.getContext(), "row_major"));
237224
}
238225

239226
// Return the load layout if it is a dot layout. If it is not, check if the
240227
// load result is converted to a dot layout. If so, return the dot layout,
241228
// otherwise return nullopt.
242-
std::optional<ttg::DotOperandEncodingAttr> getDotLayout(Operation *op) const {
243-
Value ptr = getPointerFromOp(op);
229+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
230+
OpType, tt::LoadOp, tt::StoreOp>::value>>
231+
std::optional<ttg::DotOperandEncodingAttr> getDotLayout(OpType op) const {
232+
Value ptr = op.getPtr();
244233
if (!tt::isTensorPointerType(ptr.getType()))
245234
return std::nullopt;
246235

@@ -294,10 +283,11 @@ struct TritonIntelGPUMaterializeBlockPointerPass
294283
return strideOneDim;
295284
}
296285

286+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
287+
OpType, tt::LoadOp, tt::StoreOp>::value>>
297288
bool satisfies2DBlockReadAlignment(
298-
Operation *op,
299-
tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
300-
Value ptr = getPointerFromOp(op);
289+
OpType op, tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
290+
Value ptr = op.getPtr();
301291
assert(tt::isTensorPointerType(ptr.getType()) &&
302292
"Expected a ptr to a tensor of ptrs.");
303293

0 commit comments

Comments
 (0)