
Commit e064ea8

[NFI]: Refactor MaterializeBlockPointer.cpp (#3897)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent ddd205a commit e064ea8

1 file changed

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 95 additions & 49 deletions
@@ -55,83 +55,54 @@ struct TritonIntelGPUMaterializeBlockPointerPass
              "Expected 'loadOp' to load a tensor value.");
 
       tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr);
+      if (!makeTensorPtrOp) {
+        LDBG("Could not find make tensor ptr op.");
+        return;
+      }
       LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
-      auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
-      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-      auto elementWidth = tensorType.getElementTypeBitWidth();
-      LDBG("elementWidth: " << elementWidth);
 
       Operation::operand_range shape = makeTensorPtrOp.getShape();
       unsigned rank = shape.size();
       LDBG("Rank: " << rank);
       if (rank == 1)
         return;
 
-      // We will compensate the offset of non-64 bytes aligned base to the
-      // OffsetX and BaseWidth. The OffsetX and BaseWidth has extra restriction
-      // that it has to be 4 bytes aligned.
-      auto base = makeTensorPtrOp.getBase();
-      if (!ttgi::isDivisible(base, 4)) {
-        LDBG("Found Non 4 bytes aligned base: " << base);
-        return;
-      }
-
-      Operation::operand_range strides = makeTensorPtrOp.getStrides();
-      int fastChangeDim = -1;
-      for (size_t i = 0; i < strides.size(); ++i) {
-        if (tt::intel::isConstant(strides[i], 1)) {
-          fastChangeDim = i;
-          break;
-        }
-      }
-
-      LDBG("Fast change dim: " << fastChangeDim);
-      if (fastChangeDim < 0) {
-        return;
-      }
-
-      // Check the BaseWidth.
-      Value BaseWidth = shape[fastChangeDim];
-      if (!ttgi::isDivisible(BaseWidth, std::ceil(32 / elementWidth))) {
-        LDBG("Found Non 4 bytes aligned BaseWidth: " << BaseWidth);
+      if (!satisfies2DBlockReadAlignment(loadOp)) {
+        LDBG("Alignment checks failed for: " << loadOp);
         return;
       }
 
-      // Check the OffsetX
-      Operation::operand_range offsets = makeTensorPtrOp.getOffsets();
-      Value OffsetX = offsets[fastChangeDim];
-      if (!ttgi::isDivisible(OffsetX, std::ceil(32 / elementWidth))) {
-        LDBG("Found Non 4 bytes aligned offsetX: " << OffsetX);
-        return;
-      }
+      auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
+      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+      unsigned elementWidth = tensorType.getElementTypeBitWidth();
+      LDBG("elementWidth: " << elementWidth);
 
-      // TODO: Check the OffsetX from tl.advance
+      Operation::operand_range strides = makeTensorPtrOp.getStrides();
+      std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
+      assert((strideOneDim && strideOneDim.value() < strides.size()) &&
+             "Expected strideOneDim to be set and less than strides.size()");
+      unsigned strideOneDimVal = strideOneDim.value();
 
-      if (fastChangeDim == rank - 2 && elementWidth == 8) {
+      if (strideOneDimVal == rank - 2 && elementWidth == 8) {
         // TODO: column major layout w/ fp8 has performance regression
         return;
       }
 
-      if (fastChangeDim >= (rank - 2)) {
+      if (strideOneDimVal >= (rank - 2)) {
         // HW 2D block read instruction only supports contiguous access.
-        Value fastChangeStride = strides[fastChangeDim];
-        LLVM_DEBUG({
-          DBGS() << "fastChangeStride: ";
-          fastChangeStride.print(llvm::dbgs());
-          llvm::dbgs() << "\n";
-        });
+        Value fastChangeStride = strides[strideOneDimVal];
         if (!tt::intel::isConstant(fastChangeStride, 1))
           return;
 
         // Across Intel platforms, the strictest pitch restriction is to be a
         // multiple of OWord(128 bits).
         Value pitch =
-            strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
+            strides[(strideOneDimVal == rank - 1) ? rank - 2 : rank - 1];
         LDBG("Pitch: " << pitch);
         if (!ttgi::isDivisible(pitch, 128 / elementWidth))
           return;
 
-        const bool isRowMajor = fastChangeDim == rank - 1;
+        const bool isRowMajor = (strideOneDimVal == rank - 1);
         std::optional<ttg::DotOperandEncodingAttr> dotLayout =
             getDotLayout(loadOp);
         if (dotLayout) {
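
The pitch check above requires the non-contiguous stride to be a multiple of one OWord (128 bits), i.e. 128 / elementWidth elements, while the baseWidth/offset checks (moved into satisfies2DBlockReadAlignment in the next hunk) use a 4-byte divisor of 32 / elementWidth elements. A minimal standalone sketch of that arithmetic, not part of the commit, assuming the common 8-, 16- and 32-bit element widths:

// Standalone illustration of the divisors the pass derives from the element
// bit width. Assumes elementWidth in {8, 16, 32} (f8/f16/f32); not part of
// the commit.
#include <cstdio>

int main() {
  for (unsigned elementWidth : {8u, 16u, 32u}) {
    // baseWidth and the stride-one offset must be divisible by this many
    // elements, i.e. by 4 bytes.
    unsigned divisor = 32 / elementWidth;
    // The pitch must be divisible by this many elements, i.e. by one
    // OWord (128 bits = 16 bytes).
    unsigned pitchDivisor = 128 / elementWidth;
    std::printf("elementWidth=%2u: baseWidth/offset %% %u == 0, pitch %% %u == 0\n",
                elementWidth, divisor, pitchDivisor);
  }
  return 0;
}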
@@ -292,6 +263,81 @@ struct TritonIntelGPUMaterializeBlockPointerPass
 
     return std::nullopt;
   }
+
+  std::optional<unsigned>
+  getStrideOneDim(tt::MakeTensorPtrOp makeTensorPtrOp) const {
+    assert(makeTensorPtrOp && "Expected a make tensor ptr op.");
+    Operation::operand_range strides = makeTensorPtrOp.getStrides();
+    std::optional<unsigned> strideOneDim{std::nullopt};
+    for (auto [idx, stride] : llvm::enumerate(strides)) {
+      if (!tt::intel::isConstant(stride, 1))
+        continue;
+      strideOneDim = idx;
+      break;
+    }
+    return strideOneDim;
+  }
+
+  bool satisfies2DBlockReadAlignment(tt::LoadOp loadOp) const {
+    Value ptr = loadOp.getPtr();
+    assert(tt::isTensorPointerType(ptr.getType()) &&
+           "Expected a ptr to a tensor of ptrs.");
+    assert(isa<RankedTensorType>(loadOp.getResult().getType()) &&
+           "Expected 'loadOp' to load a ranked tensor value.");
+
+    // Find the make tensor ptr operation that created the base ptr for the load
+    // operation.
+    tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr);
+    assert(makeTensorPtrOp && "Expected a make tensor ptr op.");
+
+    Operation::operand_range shape = makeTensorPtrOp.getShape();
+    if (shape.size() == 1)
+      return false;
+
+    // Ensure the base ptr is 4-byte aligned.
+    // Note: the HW requires the address to be 64-byte aligned, however we will
+    // compensate by imposing restrictions on the offsetX and baseWidth.
+    TypedValue<tt::PointerType> base = makeTensorPtrOp.getBase();
+    if (!ttgi::isDivisible(base, 4)) {
+      LDBG("Found non 4-bytes aligned base: " << base);
+      return false;
+    }
+
+    std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
+    if (!strideOneDim) {
+      LDBG("Could not find stride one dimension in: " << makeTensorPtrOp);
+      return false;
+    }
+
+    auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
+    auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+    unsigned elementWidth = tensorType.getElementTypeBitWidth();
+    unsigned strideOneDimVal = strideOneDim.value();
+    LDBG("strideOneDim: " << strideOneDimVal);
+
+    // Analyze the shape of the stride one dimension to ensure it satisfies HW
+    // constraints.
+    Value baseWidth = shape[strideOneDimVal];
+    unsigned divisor = std::ceil(32 / elementWidth);
+    if (!ttgi::isDivisible(baseWidth, divisor)) {
+      LDBG("baseWidth does not satisfy HW constraint: " << baseWidth);
+      return false;
+    }
+    LDBG("baseWidth: " << baseWidth);
+
+    // Analyze the initial offset corresponding to the stride one dimension to
+    // ensure it satisfies HW constraints.
+    Value offset = makeTensorPtrOp.getOffsets()[strideOneDimVal];
+    if (!ttgi::isDivisible(offset, divisor)) {
+      LDBG("offset does not satisfy HW constraints: " << offset);
+      return false;
+    }
+    LDBG("offset: " << offset);
+
+    // TODO: analyze tt.advance (issue #3762).
+
+    return true;
+  }
 };
 
 } // anonymous namespace
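
The new getStrideOneDim helper simply returns the first dimension whose stride is the constant 1. A self-contained sketch of the same logic over plain integer strides, with hypothetical inputs (the real helper operates on MLIR Values via tt::intel::isConstant):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the pass's getStrideOneDim(): returns the first
// dimension whose stride is 1, or std::nullopt if no dimension is contiguous.
static std::optional<unsigned>
getStrideOneDim(const std::vector<int64_t> &strides) {
  for (unsigned idx = 0; idx < strides.size(); ++idx)
    if (strides[idx] == 1)
      return idx;
  return std::nullopt;
}

int main() {
  // Row-major 2-D tensor: the last dimension is contiguous.
  assert(getStrideOneDim({1024, 1}) == 1u);
  // Column-major 2-D tensor: the first dimension is contiguous.
  assert(getStrideOneDim({1, 1024}) == 0u);
  // No unit-stride dimension: the pass bails out.
  assert(!getStrideOneDim({2048, 2}).has_value());
  return 0;
}

A std::nullopt result makes satisfies2DBlockReadAlignment return false, matching the old fastChangeDim < 0 early return in the code this commit replaces.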
