
using namespace mlir;
namespace tt = mlir::triton;
+ namespace ttg = mlir::triton::gpu;
namespace ttgi = mlir::triton::gpu::intel;

namespace mlir::triton::gpu::intel {
@@ -37,7 +38,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
      return;

    MLIRContext *context = &getContext();
-     mod.walk([context](tt::LoadOp loadOp) {
+     mod.walk([context, this](tt::LoadOp loadOp) {
      LDBG("Considering op: " << loadOp);

      Value ptr = loadOp.getPtr();
@@ -51,7 +52,6 @@ struct TritonIntelGPUMaterializeBlockPointerPass
      LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
      auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-       auto dotLayout = ttgi::getDotEncoding(tensorType);

      Operation::operand_range shape = makeTensorPtrOp.getShape();
      unsigned rank = shape.size();
@@ -100,11 +100,11 @@ struct TritonIntelGPUMaterializeBlockPointerPass
        return;

      const bool isRowMajor = fastChangeDim == rank - 1;
-       if (dotLayout) {
-         // Check if the load is being used in a dot layout, and if so is this
-         // the first op and is it a transposed row major matrix. If so, skip
-         // the block ptr attribute as performance is worse than if we remove
-         // the tensor pointer
+       if (auto dotLayout = getDotLayout(loadOp)) {
+         // Check if the load is being used by a tt.dot operation, and if so is
+         // this the first operand and is it a transposed row major matrix. If
+         // so, skip the block ptr attribute as performance is worse than if we
+         // remove the tensor pointer.
        LDBG("dotLayout: " << *dotLayout);
        const unsigned opIdx = dotLayout->getOpIdx();
        auto dotOrder = dotLayout->getThreadOrder();
@@ -122,6 +122,46 @@ struct TritonIntelGPUMaterializeBlockPointerPass
      }
    });
  }
+
+ private:
+   // Return the load layout if it is a dot layout. If it is not, check if the
+   // load result is converted to a dot layout. If so, return the dot layout,
+   // otherwise return nullopt.
+   std::optional<ttg::DotOperandEncodingAttr>
+   getDotLayout(tt::LoadOp loadOp) const {
+     Value ptr = loadOp.getPtr();
+     if (!tt::isTensorPointerType(ptr.getType()))
+       return std::nullopt;
+
+     RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType());
+     auto dotLayout = ttgi::getDotEncoding(tensorType);
+     if (dotLayout)
+       return dotLayout;
+
+     auto allUsersAreConvertOps = [](Operation::user_range users) {
+       return llvm::all_of(users, [](Operation *user) {
+         return isa<ttg::ConvertLayoutOp>(user);
+       });
+     };
+
+     auto allUserHaveIdenticalLayout = [](Operation::user_range users) {
+       Attribute firstUserLayout =
+           cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+       return llvm::all_of(users, [&firstUserLayout](Operation *user) {
+         return firstUserLayout ==
+                cast<ttg::ConvertLayoutOp>(user).getType().getEncoding();
+       });
+     };
+
+     Operation::user_range users = loadOp->getUsers();
+     if (allUsersAreConvertOps(users) && allUserHaveIdenticalLayout(users)) {
+       Attribute firstUserLayout =
+           cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+       return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
+     }
+
+     return std::nullopt;
+   }
};

} // anonymous namespace
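
The new getDotLayout helper covers the case where the load result does not itself carry a dot-operand encoding, but every user of the load is a ttg::ConvertLayoutOp into one and the same dot-operand encoding. A rough sketch of that IR shape follows; the layout attribute definitions are elided and the textual dialect prefixes and operand syntax vary between Triton versions, so treat it purely as an illustration rather than IR taken from this PR:

  // #blocked stands for some non-dot encoding on the loaded tensor; #dot_a
  // stands for a ttg::DotOperandEncodingAttr (e.g. opIdx = 0 with a DPAS
  // parent); #dot_b and #dpas are the matching operand/accumulator layouts.
  %a     = tt.load %block_ptr : !tt.ptr<tensor<64x32xf16, #blocked>>
  %a_cvt = ttg.convert_layout %a : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot_a>
  %acc1  = tt.dot %a_cvt, %b_cvt, %acc : tensor<64x32xf16, #dot_a> * tensor<32x64xf16, #dot_b> -> tensor<64x64xf32, #dpas>

Here ttgi::getDotEncoding on the load's tensor type finds nothing, so the helper looks through the convert_layout users and returns #dot_a, which lets the opIdx == 0 / transposed row-major check above apply to this load as well.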