Skip to content

Commit db2a0e9

Browse files
committed
Use stride instead of order to determine block attr
1 parent 5b1c668 commit db2a0e9

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,28 @@ struct TritonIntelGPUMaterializeBlockPointerPass
5151
LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
5252
auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
5353
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
54-
ArrayRef<int32_t> order = makeTensorPtrOp.getOrder();
55-
unsigned rank = order.size();
54+
Operation::operand_range shape = makeTensorPtrOp.getShape();
55+
unsigned rank = shape.size();
5656
LDBG("Rank: " << rank);
5757
if (rank == 1)
5858
return;
5959

60-
unsigned fastChangeDim = order[0];
60+
Operation::operand_range strides = makeTensorPtrOp.getStrides();
61+
int fastChangeDim = -1;
62+
for (size_t i = 0; i < strides.size(); i++) {
63+
if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) {
64+
fastChangeDim = i;
65+
break;
66+
}
67+
}
6168
LDBG("Fast change dim: " << fastChangeDim);
69+
if (fastChangeDim < 0) {
70+
return;
71+
}
72+
ArrayRef<int32_t> order = makeTensorPtrOp.getOrder();
73+
74+
// unsigned fastChangeDim = order[0];
6275
if (fastChangeDim >= (rank - 2)) {
63-
Operation::operand_range strides = makeTensorPtrOp.getStrides();
6476

6577
// HW 2D block read instruction only supports contiguous access.
6678
Value fastChangeStride = strides[fastChangeDim];

0 commit comments

Comments
 (0)