
Commit bf881ed

Use block load attribute to remove duplicate logic from MaterializeBlockPointer pass
1 parent 414eba6

1 file changed: 26 additions & 56 deletions

third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

@@ -33,7 +33,7 @@ namespace {
 /// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
 /// - its pitch is not divisible by Qword bitwidth
 /// - it is not contiguous in memory
-bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
+bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp) {
   LDBG("Considering removal of: " << op);
   if (!op->getParentOfType<ModuleOp>()->hasAttr(
           ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
@@ -52,55 +52,7 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
          "by load or store op with DPAS layout");
     return true;
   }
-
-  TypedValue<triton::PointerType> base = op.getBase();
-  Operation::operand_range shape = op.getShape();
-  unsigned rank = shape.size();
-  assert(rank > 1 && "Expecting tensor with rank > 1");
-  Operation::operand_range strides = op.getStrides();
-  Operation::operand_range offsets = op.getOffsets();
-  ArrayRef<int32_t> order = op.getOrder();
-  ArrayRef<int64_t> tensorShape = tensorType.getShape();
-
-  int fastChangeDim = -1;
-  for (size_t i = 0; i < strides.size(); ++i) {
-    if (ttgi::isConstant(strides[i], 1)) {
-      fastChangeDim = i;
-      break;
-    }
-  }
-
-  LDBG("fastChangeDim: " << fastChangeDim);
-  if (fastChangeDim < 0) {
-    LDBG("Marked for removal: fast changing dimension not found");
-    return true;
-  }
-
-  LDBG("Tensor type element type bit width: "
-       << tensorType.getElementTypeBitWidth());
-  if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
-    // TODO: column major layout w/ fp8 has performance regression
-    LDBG("Marked for removal: column major layout with fp8 element type");
-    return true;
-  }
-
-  // HW 2D block read instruction has restriction on pitch divisibility
-  if (fastChangeDim >= (rank - 2)) {
-    auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
-    LDBG("Pitch: " << pitch);
-    // Across Intel platforms, the strictest pitch restriction is to be a
-    // multiple of OWord (128 bits).
-    if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
-      LDBG("Marked for removal: cannot use block read/write instructions");
-      return true;
-    }
-
-    return false;
-  }
-
-  LDBG("Marked for removal: fall-through");
-
-  return true;
+  return false;
 }
 
 /// The `RewritedInfo` struct is used to store information about a rewritten
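The checks deleted above (fast-changing-dimension lookup, the fp8 column-major exclusion, and OWord pitch divisibility) duplicated the screening that MaterializeBlockPointer already performs before it tags a qualifying load with the block-IO attribute. A minimal producer-side sketch of that idea, assuming the attribute is a StringAttr naming the access layout; the helper name tagBlockIO and the layout string are illustrative, not the actual MaterializeBlockPointer code:

// Sketch only: record the block-read verdict on the load so downstream
// passes (like this one) can test the attribute instead of re-running the
// pitch/contiguity analysis. Assumes the same headers and namespace
// aliases (tt, ttgi) as the file above.
static void tagBlockIO(tt::LoadOp loadOp, llvm::StringRef layout) {
  loadOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
                  mlir::StringAttr::get(loadOp.getContext(), layout));
}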
@@ -715,10 +667,19 @@ class TritonIntelGPURewriteTensorPointerPass
   void runOnOperation() override {
     ModuleOp mod = getOperation();
 
-    auto usedByLoadOrStoreOp = [](Value val) {
-      return llvm::any_of(val.getUsers(), [](Operation *user) {
-        return isa<tt::LoadOp, tt::StoreOp>(user);
-      });
+    // TODO: do we need this attribute?
+    auto usedByLoadOrStoreOp = [](Value val,
+                                  const bool check_block_io_attribute = false) {
+      return llvm::any_of(
+          val.getUsers(), [check_block_io_attribute](Operation *user) {
+            const bool is_load_or_store = isa<tt::LoadOp, tt::StoreOp>(user);
+            if (check_block_io_attribute) {
+              return user->hasAttr(
+                  ttgi::TritonIntelGPUDialect::getBlockIOAttrName());
+            } else {
+              return is_load_or_store;
+            }
+          });
     };
 
     auto markTensorPointerForRemoval =
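For illustration, the two modes of the updated lambda behave as follows (a sketch; val stands for any SSA value in scope):

// Default mode: true when any user of `val` is a tt.load or tt.store.
bool direct = usedByLoadOrStoreOp(val);
// Attribute mode: true when any user carries the block-IO attribute; as
// written, the is_load_or_store result is ignored in this branch.
bool viaAttr = usedByLoadOrStoreOp(val, /*check_block_io_attribute=*/true);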
@@ -738,8 +699,17 @@ class TritonIntelGPURewriteTensorPointerPass
         markTensorPointerForRemoval(op->getOperand(0),
                                     isa<tt::LoadOp, tt::StoreOp>(op));
       } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-        for (auto arg : forOp.getInitArgs())
-          markTensorPointerForRemoval(arg);
+        for (auto [arg, blockArg] :
+             llvm::zip(forOp.getInitArgs(),
+                       forOp.getBody()->getArguments().drop_front(
+                           forOp.getNumInductionVars()))) {
+          if (isa<tt::MakeTensorPtrOp>(arg.getDefiningOp())) {
+            constexpr bool check_block_io_attribute = true;
+            markTensorPointerForRemoval(
+                arg.getDefiningOp()->getResult(0),
+                usedByLoadOrStoreOp(blockArg, check_block_io_attribute));
+          }
+        }
       } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
         for (auto operand : yieldOp.getOperands())
           markTensorPointerForRemoval(operand);
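One caveat in the new scf::ForOp handling: Value::getDefiningOp() returns a null pointer when the init argument is itself a block argument, and isa<> asserts on null. A guarded variant of the loop body using the templated getDefiningOp<OpTy>() accessor, offered as a sketch rather than part of the commit:

// Guarded form: getDefiningOp<tt::MakeTensorPtrOp>() yields a null op both
// when `arg` has no defining op and when the producer is a different op,
// so one check covers both cases.
if (auto makePtrOp = arg.getDefiningOp<tt::MakeTensorPtrOp>()) {
  constexpr bool check_block_io_attribute = true;
  markTensorPointerForRemoval(
      makePtrOp.getResult(),
      usedByLoadOrStoreOp(blockArg, check_block_io_attribute));
}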
