@@ -3,11 +3,10 @@
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "intel/include/Utils/Utility.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
-#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -36,137 +35,129 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       TritonIntelGPUMaterializeBlockPointerPass>::
       TritonIntelGPUMaterializeBlockPointerBase;
 
-  static Value getPointerFromOp(Operation *op) {
-    return TypeSwitch<Operation *, Value>(op)
-        .Case<tt::LoadOp, tt::StoreOp>([](auto op) { return op.getPtr(); })
-        .Default([&](auto) {
-          llvm_unreachable(
-              ("Invalid operation: " + op->getName().getStringRef())
-                  .str()
-                  .c_str());
-          return Value{};
-        });
-  }
-
   void runOnOperation() override {
     ModuleOp mod = getOperation();
     if (!mod->hasAttr(
             ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName()))
       return;
 
     tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
-
     MLIRContext *context = &getContext();
-    mod.walk([&](Operation *op) {
-      if (!isa<tt::LoadOp, tt::StoreOp>(op)) {
-        return;
-      }
-      LDBG("Considering op: " << *op);
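+    // Typed walks: each callback fires only for the matching op kind.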
+    mod.walk(
+        [&](tt::LoadOp op) { return visit(op, axisInfoAnalysis, context); });
+    mod.walk(
+        [&](tt::StoreOp op) { return visit(op, axisInfoAnalysis, context); });
+  }
 
-      Value ptr = getPointerFromOp(op);
-      if (!tt::isTensorPointerType(ptr.getType()))
-        return MaterializeTensorOfPointers(op, axisInfoAnalysis);
+private:
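+  // visit() accepts only tt::LoadOp and tt::StoreOp, enforced via SFINAE.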
+  template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
+                OpType, tt::LoadOp, tt::StoreOp>::value>>
+  void visit(OpType op, tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+             MLIRContext *context) const {
+    LDBG("Considering op: " << *op);
 
-      // Find the make tensor ptr operation that created the base ptr.
-      std::optional<tt::MakeTensorPtrOp> defOp =
-          tt::intel::findDefiningMakeTensorPtrOp(ptr);
-      if (!defOp) {
-        LDBG("Could not find make tensor ptr op for: " << *op);
-        return;
-      }
+    Value ptr = op.getPtr();
+    if (!tt::isTensorPointerType(ptr.getType()))
+      return MaterializeTensorOfPointers(op, axisInfoAnalysis);
 
-      tt::MakeTensorPtrOp makeTensorPtrOp = *defOp;
-      LDBG("Make tensor ptr op: " << makeTensorPtrOp);
+    // Find the make tensor ptr operation that created the base ptr.
+    std::optional<tt::MakeTensorPtrOp> defOp =
+        tt::intel::findDefiningMakeTensorPtrOp(ptr);
+    if (!defOp) {
+      LDBG("Could not find make tensor ptr op for: " << *op);
+      return;
+    }
 
-      Operation::operand_range shape = makeTensorPtrOp.getShape();
-      unsigned rank = shape.size();
-      LDBG("Rank: " << rank);
-      if (rank == 1)
-        return;
+    tt::MakeTensorPtrOp makeTensorPtrOp = *defOp;
+    LDBG("Make tensor ptr op: " << makeTensorPtrOp);
 
-      if (!satisfies2DBlockReadAlignment(op, axisInfoAnalysis)) {
-        LDBG("Alignment checks failed for: " << *op);
-        return;
-      }
+    Operation::operand_range shape = makeTensorPtrOp.getShape();
+    unsigned rank = shape.size();
+    LDBG("Rank: " << rank);
+    if (rank == 1)
+      return;
 
-      auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
-      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-      unsigned elementWidth = tensorType.getElementTypeBitWidth();
-      LDBG("elementWidth: " << elementWidth);
+    if (!satisfies2DBlockReadAlignment(op, axisInfoAnalysis)) {
+      LDBG("Alignment checks failed for: " << *op);
+      return;
+    }
 
-      Operation::operand_range strides = makeTensorPtrOp.getStrides();
-      std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
-      assert((strideOneDim && strideOneDim.value() < strides.size()) &&
-             "Expected strideOneDim to be set and less than strides.size()");
-      unsigned strideOneDimVal = strideOneDim.value();
+    auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
+    auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+    unsigned elementWidth = tensorType.getElementTypeBitWidth();
+    LDBG("elementWidth: " << elementWidth);
 
-      if (strideOneDimVal == rank - 2 && elementWidth == 8) {
-        // TODO: column major layout w/ fp8 has performance regression
+    Operation::operand_range strides = makeTensorPtrOp.getStrides();
+    std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
+    assert((strideOneDim && strideOneDim.value() < strides.size()) &&
+           "Expected strideOneDim to be set and less than strides.size()");
+    unsigned strideOneDimVal = strideOneDim.value();
+
+    if (strideOneDimVal == rank - 2 && elementWidth == 8) {
+      // TODO: column major layout w/ fp8 has performance regression
+      return;
+    }
+
+    if (strideOneDimVal >= (rank - 2)) {
+      // HW 2D block read instruction only supports contiguous access.
+      Value fastChangeStride = strides[strideOneDimVal];
+      if (!tt::intel::isConstant(fastChangeStride, 1))
         return;
-      }
 
-      if (strideOneDimVal >= (rank - 2)) {
-        // HW 2D block read instruction only supports contiguous access.
-        Value fastChangeStride = strides[strideOneDimVal];
-        if (!tt::intel::isConstant(fastChangeStride, 1))
-          return;
+      // Across Intel platforms, the strictest pitch restriction is to be a
+      // multiple of OWord(128 bits).
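+      // e.g. divideCeil(128, 16) == 8: for 16-bit elements the pitch must
+      // be a multiple of 8 elements (16 bytes, one OWord).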
+      Value pitch =
+          strides[(strideOneDimVal == rank - 1) ? rank - 2 : rank - 1];
+      LDBG("Pitch: " << pitch);
+      if (!ttgi::isDivisible(pitch, llvm::divideCeil(128, elementWidth)))
+        return;
 
-        // Across Intel platforms, the strictest pitch restriction is to be a
-        // multiple of OWord(128 bits).
-        Value pitch =
-            strides[(strideOneDimVal == rank - 1) ? rank - 2 : rank - 1];
-        LDBG("Pitch: " << pitch);
-        if (!ttgi::isDivisible(pitch, llvm::divideCeil(128, elementWidth)))
+      const bool isRowMajor = (strideOneDimVal == rank - 1);
+      std::optional<ttg::DotOperandEncodingAttr> dotLayout = getDotLayout(op);
+      if (dotLayout) {
+        // Check if the load is being used by a tt.dot operation, and if so is
+        // this the first operand and is it a transposed row major matrix. If
+        // so, skip the block ptr attribute as performance is worse than if we
+        // remove the tensor pointer.
+        LDBG("dotLayout: " << *dotLayout);
+        auto opIdx =
+            static_cast<ttgi::DpasEncodingAttr::OpIdx>(dotLayout->getOpIdx());
+        auto dotOrder = tt::gpu::getThreadOrder(tensorType);
+        const bool valueRowMajor = (dotOrder[0] == 1 && dotOrder[1] == 0);
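+        // XOR: value order and memory order disagree, i.e. A is transposed.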
+        if (opIdx == ttgi::DpasEncodingAttr::OpIdx::OperandA &&
+            valueRowMajor ^ isRowMajor) {
+          LDBG("Skipping block pointer attribute for transposed A matrix in "
+               "dot operation");
           return;
-
-        const bool isRowMajor = (strideOneDimVal == rank - 1);
-        std::optional<ttg::DotOperandEncodingAttr> dotLayout = getDotLayout(op);
-        if (dotLayout) {
-          // Check if the load is being used by a tt.dot operation, and if so is
-          // this the first operand and is it a transposed row major matrix. If
-          // so, skip the block ptr attribute as performance is worse than if we
-          // remove the tensor pointer.
-          LDBG("dotLayout: " << *dotLayout);
-          auto opIdx =
-              static_cast<ttgi::DpasEncodingAttr::OpIdx>(dotLayout->getOpIdx());
-          auto dotOrder = tt::gpu::getThreadOrder(tensorType);
-          const bool valueRowMajor = (dotOrder[0] == 1 && dotOrder[1] == 0);
-          if (opIdx == ttgi::DpasEncodingAttr::OpIdx::OperandA &&
-              valueRowMajor ^ isRowMajor) {
-            LDBG("Skipping block pointer attribute for transposed A matrix in "
-                 "dot operation");
-            return;
-          }
         }
-
-        op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
-                    StringAttr::get(context,
-                                    isRowMajor ? "row_major" : "column_major"));
       }
-    });
+
+      op->setAttr(
+          ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
+          StringAttr::get(context, isRowMajor ? "row_major" : "column_major"));
+    }
   }
 
-private:
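+  // Handles ops whose ptr is a tensor of pointers (no block pointer).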
+  template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
+                OpType, tt::LoadOp, tt::StoreOp>::value>>
   void MaterializeTensorOfPointers(
-      Operation *op,
-      tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
-    MLIRContext *context = op->getContext();
-    Value ptr = getPointerFromOp(op);
+      OpType op, tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
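+    // Only loads carry a mask; masked accesses cannot use block IO.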
+    if constexpr (std::is_same_v<OpType, tt::LoadOp>) {
+      if (op.getMask()) {
+        LDBG("Load op has mask, skip block IO attribute");
+        return;
+      }
+    }
+
+    Value ptr = op.getPtr();
     assert(!tt::isTensorPointerType(ptr.getType()) &&
            "Expected pointer refer to a tensor.");
 
     auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
     if (!tensorTy)
       return;
 
-    LDBG("Considering tensor of pointer of memory accessing op: " << *op);
-
-    if (auto loadOp = dyn_cast<tt::LoadOp>(*op)) {
-      if (loadOp.getMask()) {
-        LDBG("Load op has mask, skip block IO attribute");
-        return;
-      }
-    }
+    LDBG("Considering tensor of pointer of memory accessing op: " << op);
 
     // The axis info gives the information about the value of the indices
     // tensor. For example, if the indices tensor is tensor<8x16xi32> and
@@ -187,60 +178,58 @@ struct TritonIntelGPUMaterializeBlockPointerPass
     }
 
     // Determine if LoadOp is row-major or column-major.
-    auto isMajor = [&](unsigned fastChangeDim) {
+    auto isMajor = [](RankedTensorType tensorTy, unsigned fastChangeDim,
+                      const tt::AxisInfo &axisInfo) {
       assert((fastChangeDim == 0 || fastChangeDim == 1) &&
              "fastChangeDim is expected to be 0 or 1");
       const unsigned otherDim = !fastChangeDim;
       // Limit to full row being contiguous.
-      if (axisInfo->getContiguity(fastChangeDim) !=
+      if (axisInfo.getContiguity(fastChangeDim) !=
           tensorTy.getDimSize(fastChangeDim)) {
         LDBG("Found non-contiguous row: "
-             << axisInfo->getContiguity(fastChangeDim));
+             << axisInfo.getContiguity(fastChangeDim));
         return false;
       }
 
       // Value -1 is used to represent the unknown stride.
-      if (axisInfo->getStride(otherDim) < 0) {
-        LDBG("Found unknown stride: " << axisInfo->getStride(otherDim));
+      if (axisInfo.getStride(otherDim) < 0) {
+        LDBG("Found unknown stride: " << axisInfo.getStride(otherDim));
         return false;
       }
 
       // Surface pitch is required to be 16 bytes aligned.
       Type elemTy =
           cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
       unsigned elemSizeInBytes = elemTy.getIntOrFloatBitWidth() / 8;
-      if ((axisInfo->getStride(otherDim) * elemSizeInBytes) % 16 != 0) {
+      if ((axisInfo.getStride(otherDim) * elemSizeInBytes) % 16 != 0) {
         LDBG("Found Non 16 bytes aligned stride: "
-             << axisInfo->getStride(otherDim));
+             << axisInfo.getStride(otherDim));
         return false;
       }
 
       // Base pointer can be compensate by the offset and base width, where they
       // each has restriction that it has to be 4 bytes aligned.
-      if (axisInfo->getDivisibility(fastChangeDim) % 4 != 0) {
-        LDBG(
-            "Found Non 4 bytes aligned base: " << axisInfo->getDivisibility(1));
+      if (axisInfo.getDivisibility(fastChangeDim) % 4 != 0) {
+        LDBG("Found Non 4 bytes aligned base: " << axisInfo.getDivisibility(1));
         return false;
       }
 
       return true;
     };
 
-    // Check if loadOp is row major, i.e., fast changing dimension is one.
-    if (isMajor(1 /*fastChangeDim*/)) {
-      LDBG("Setting row_major attribute\n");
+    const bool isRowMajor = isMajor(tensorTy, 1 /*fastChangeDim*/, *axisInfo);
+    if (isRowMajor)
       op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
-                  StringAttr::get(context, "row_major"));
-    }
-
-    // TODO: set column_major attribute
+                  StringAttr::get(op.getContext(), "row_major"));
   }
 
   // Return the load layout if it is a dot layout. If it is not, check if the
   // load result is converted to a dot layout. If so, return the dot layout,
   // otherwise return nullopt.
-  std::optional<ttg::DotOperandEncodingAttr> getDotLayout(Operation *op) const {
-    Value ptr = getPointerFromOp(op);
+  template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
+                OpType, tt::LoadOp, tt::StoreOp>::value>>
+  std::optional<ttg::DotOperandEncodingAttr> getDotLayout(OpType op) const {
+    Value ptr = op.getPtr();
     if (!tt::isTensorPointerType(ptr.getType()))
       return std::nullopt;
 
@@ -294,10 +283,11 @@ struct TritonIntelGPUMaterializeBlockPointerPass
     return strideOneDim;
   }
 
+  template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
+                OpType, tt::LoadOp, tt::StoreOp>::value>>
   bool satisfies2DBlockReadAlignment(
-      Operation *op,
-      tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
-    Value ptr = getPointerFromOp(op);
+      OpType op, tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis) const {
+    Value ptr = op.getPtr();
     assert(tt::isTensorPointerType(ptr.getType()) &&
            "Expected a ptr to a tensor of ptrs.");
 