Commit 979301f

Use stride instead of order to determine block attr (#2349)
Per the Triton Slack, `order` is unused on architectures below Hopper. More importantly, `order` provides information that the strides already carry. In fact, `order` can disagree with the strides entirely (i.e. be wrong) and we still generate correct code. I think it is better to rely on the strides, assuming the logic added here makes sense. Note this depends on #2348; I'd like to land the debug logging separately so we keep it even if we decide to modify this approach. It was very useful in debugging this problem. cc #2347
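As a standalone illustration of why the strides suffice: the fast-changing dimension is simply the dimension with unit stride, so a correct `order` is redundant and a wrong one is harmless. The sketch below is hypothetical, not the pass code; plain integers stand in for the MLIR stride `Value`s that the pass checks with `mlir::triton::gpu::intel::isConstant(strides[i], 1)`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the commit's logic: scan the strides and
// return the index of the first unit-stride (contiguous) dimension, or -1
// if none is statically known to be 1.
int findFastChangeDim(const std::vector<int64_t> &strides) {
  for (size_t i = 0; i < strides.size(); ++i)
    if (strides[i] == 1)
      return static_cast<int>(i);
  return -1;
}

int main() {
  // Row-major 2D tensor: strides {N, 1} -> dim 1 is fast-changing,
  // no matter what an `order` attribute claims.
  std::vector<int64_t> rowMajor = {1024, 1};
  // Column-major 2D tensor: strides {1, M} -> dim 0 is fast-changing.
  std::vector<int64_t> colMajor = {1, 512};
  bool ok = findFastChangeDim(rowMajor) == 1 &&
            findFastChangeDim(colMajor) == 0;
  return ok ? 0 : 1;
}
```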
1 parent: 19527ac

File tree

1 file changed: +17 −6 lines changed


third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 17 additions & 6 deletions
@@ -51,17 +51,27 @@ struct TritonIntelGPUMaterializeBlockPointerPass
     LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
     auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
     auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-    ArrayRef<int32_t> order = makeTensorPtrOp.getOrder();
-    unsigned rank = order.size();
+    Operation::operand_range shape = makeTensorPtrOp.getShape();
+    unsigned rank = shape.size();
     LDBG("Rank: " << rank);
     if (rank == 1)
       return;
 
-    unsigned fastChangeDim = order[0];
+    Operation::operand_range strides = makeTensorPtrOp.getStrides();
+    int fastChangeDim = -1;
+    for (size_t i = 0; i < strides.size(); ++i) {
+      if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) {
+        fastChangeDim = i;
+        break;
+      }
+    }
+
     LDBG("Fast change dim: " << fastChangeDim);
-    if (fastChangeDim >= (rank - 2)) {
-      Operation::operand_range strides = makeTensorPtrOp.getStrides();
+    if (fastChangeDim < 0) {
+      return;
+    }
 
+    if (fastChangeDim >= (rank - 2)) {
       // HW 2D block read instruction only supports contiguous access.
       Value fastChangeStride = strides[fastChangeDim];
       LLVM_DEBUG({
@@ -77,7 +87,8 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       Value pitch =
           strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
       LDBG("Pitch: " << pitch);
-      if (!ttgi::isDivisible(pitch, 64 / tensorType.getElementTypeBitWidth()))
+      if (!ttgi::isDivisible(pitch,
+                             128 / tensorType.getElementTypeBitWidth()))
         return;
 
       loadOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
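Besides the stride scan, the second hunk doubles the pitch-alignment requirement from 64 to 128 bits. Below is a minimal sketch of the arithmetic, assuming (as the code suggests) that `isDivisible`'s divisor is expressed in elements, so the pitch must span a multiple of 16 bytes:

```cpp
#include <iostream>

// Elements the pitch must be a multiple of: 128 bits' worth of the
// tensor's element type (64 bits before this commit).
constexpr unsigned requiredPitchMultiple(unsigned elementBitWidth) {
  return 128 / elementBitWidth;
}

int main() {
  // fp16 (16-bit): pitch must be a multiple of 8 elements (16 bytes);
  // the previous 64-bit rule only required 4 elements.
  std::cout << requiredPitchMultiple(16) << "\n"; // 8
  // fp32 (32-bit): multiple of 4 elements, previously 2.
  std::cout << requiredPitchMultiple(32) << "\n"; // 4
}
```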
