Commit 7c9a0f9

MaterializeBlockPointer fix for GEMM with 1st operand transposed
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent eeda8e9 commit 7c9a0f9

1 file changed
third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 47 additions & 7 deletions
@@ -12,6 +12,7 @@
 
 using namespace mlir;
 namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
 namespace ttgi = mlir::triton::gpu::intel;
 
 namespace mlir::triton::gpu::intel {
@@ -37,7 +38,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       return;
 
     MLIRContext *context = &getContext();
-    mod.walk([context](tt::LoadOp loadOp) {
+    mod.walk([context, this](tt::LoadOp loadOp) {
       LDBG("Considering op: " << loadOp);
 
       Value ptr = loadOp.getPtr();
@@ -51,7 +52,6 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
       auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
       auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-      auto dotLayout = ttgi::getDotEncoding(tensorType);
 
       Operation::operand_range shape = makeTensorPtrOp.getShape();
       unsigned rank = shape.size();
@@ -100,11 +100,11 @@ struct TritonIntelGPUMaterializeBlockPointerPass
         return;
 
       const bool isRowMajor = fastChangeDim == rank - 1;
-      if (dotLayout) {
-        // Check if the load is being used in a dot layout, and if so is this
-        // the first op and is it a transposed row major matrix. If so, skip
-        // the block ptr attribute as performance is worse than if we remove
-        // the tensor pointer
+      if (auto dotLayout = getDotLayout(loadOp)) {
+        // Check if the load is being used by a tt.dot operation, and if so is
+        // this the first operand and is it a transposed row major matrix. If
+        // so, skip the block ptr attribute as performance is worse than if we
+        // remove the tensor pointer.
         LDBG("dotLayout: " << *dotLayout);
         const unsigned opIdx = dotLayout->getOpIdx();
         auto dotOrder = dotLayout->getThreadOrder();
@@ -122,6 +122,46 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       }
     });
   }
+
+private:
+  // Return the load layout if it is a dot layout. If it is not, check if the
+  // load result is converted to a dot layout. If so, return the dot layout,
+  // otherwise return nullopt.
+  std::optional<ttg::DotOperandEncodingAttr>
+  getDotLayout(tt::LoadOp loadOp) const {
+    Value ptr = loadOp.getPtr();
+    if (!tt::isTensorPointerType(ptr.getType()))
+      return nullptr;
+
+    RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType());
+    auto dotLayout = ttgi::getDotEncoding(tensorType);
+    if (dotLayout)
+      return dotLayout;
+
+    auto allUsersAreConvertOps = [](Operation::user_range users) {
+      return llvm::all_of(users, [](Operation *user) {
+        return isa<ttg::ConvertLayoutOp>(user);
+      });
+    };
+
+    auto allUserHaveIdenticalLayout = [](Operation::user_range users) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      return llvm::all_of(users, [&firstUserLayout](Operation *user) {
+        return firstUserLayout ==
+               cast<ttg::ConvertLayoutOp>(user).getType().getEncoding();
+      });
+    };
+
+    Operation::user_range users = loadOp->getUsers();
+    if (allUsersAreConvertOps(users) && allUserHaveIdenticalLayout(users)) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
+    }
+
+    return nullptr;
+  }
 };
 
 } // anonymous namespace
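
For readers skimming the diff: the new getDotLayout helper accepts a load when its result type already carries a dot-operand encoding, or when every user of the load is a layout-conversion op and all of those conversions target one identical dot-operand encoding. Below is a minimal, self-contained C++ sketch of that "all users convert to the same layout" check on plain data; the User struct, isConvert flag, and layout strings are illustrative stand-ins for the MLIR types used in the pass, not part of the commit.

#include <algorithm>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Illustrative stand-in for one use of the load result: either a layout
// conversion (with its target encoding) or some other operation.
struct User {
  bool isConvert;
  std::string targetLayout; // meaningful only when isConvert is true
};

// Mirrors the decision in getDotLayout: every user must be a conversion and
// all conversions must agree on a single target layout; otherwise bail out.
std::optional<std::string>
commonConvertLayout(const std::vector<User> &users) {
  if (users.empty())
    return std::nullopt;
  bool allConverts = std::all_of(users.begin(), users.end(),
                                 [](const User &u) { return u.isConvert; });
  if (!allConverts)
    return std::nullopt;
  const std::string &first = users.front().targetLayout;
  bool identical =
      std::all_of(users.begin(), users.end(),
                  [&](const User &u) { return u.targetLayout == first; });
  return identical ? std::optional<std::string>(first) : std::nullopt;
}

int main() {
  // Both users convert the load result to the same dot-operand layout.
  std::vector<User> uniform = {{true, "dot_op<opIdx = 0>"},
                               {true, "dot_op<opIdx = 0>"}};
  // One user is not a layout conversion, so no common layout is reported.
  std::vector<User> mixed = {{true, "dot_op<opIdx = 0>"}, {false, ""}};

  std::cout << commonConvertLayout(uniform).value_or("<none>") << "\n";
  std::cout << commonConvertLayout(mixed).value_or("<none>") << "\n";
  return 0;
}

In the pass itself this user-based check is only consulted after the tensor type's own encoding fails to yield a dot layout, so the existing fast path is unchanged.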
