@@ -55,83 +55,54 @@ struct TritonIntelGPUMaterializeBlockPointerPass
            "Expected 'loadOp' to load a tensor value.");

     tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr);
+    if (!makeTensorPtrOp) {
+      LDBG("Could not find make tensor ptr op.");
+      return;
+    }
     LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
-    auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
-    auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-    auto elementWidth = tensorType.getElementTypeBitWidth();
-    LDBG("elementWidth: " << elementWidth);

     Operation::operand_range shape = makeTensorPtrOp.getShape();
     unsigned rank = shape.size();
     LDBG("Rank: " << rank);
     if (rank == 1)
       return;

-    // We will compensate the offset of non-64 bytes aligned base to the
-    // OffsetX and BaseWidth. The OffsetX and BaseWidth has extra restriction
-    // that it has to be 4 bytes aligned.
-    auto base = makeTensorPtrOp.getBase();
-    if (!ttgi::isDivisible(base, 4)) {
-      LDBG("Found Non 4 bytes aligned base: " << base);
-      return;
-    }
-
-    Operation::operand_range strides = makeTensorPtrOp.getStrides();
-    int fastChangeDim = -1;
-    for (size_t i = 0; i < strides.size(); ++i) {
-      if (tt::intel::isConstant(strides[i], 1)) {
-        fastChangeDim = i;
-        break;
-      }
-    }
-
-    LDBG("Fast change dim: " << fastChangeDim);
-    if (fastChangeDim < 0) {
-      return;
-    }
-
-    // Check the BaseWidth.
-    Value BaseWidth = shape[fastChangeDim];
-    if (!ttgi::isDivisible(BaseWidth, std::ceil(32 / elementWidth))) {
-      LDBG("Found Non 4 bytes aligned BaseWidth: " << BaseWidth);
+    if (!satisfies2DBlockReadAlignment(loadOp)) {
+      LDBG("Alignment checks failed for: " << loadOp);
       return;
     }

-    // Check the OffsetX
-    Operation::operand_range offsets = makeTensorPtrOp.getOffsets();
-    Value OffsetX = offsets[fastChangeDim];
-    if (!ttgi::isDivisible(OffsetX, std::ceil(32 / elementWidth))) {
-      LDBG("Found Non 4 bytes aligned offsetX: " << OffsetX);
-      return;
-    }
+    auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
+    auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+    unsigned elementWidth = tensorType.getElementTypeBitWidth();
+    LDBG("elementWidth: " << elementWidth);

-    // TODO: Check the OffsetX from tl.advance
+    Operation::operand_range strides = makeTensorPtrOp.getStrides();
+    std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
+    assert((strideOneDim && strideOneDim.value() < strides.size()) &&
+           "Expected strideOneDim to be set and less than strides.size()");
+    unsigned strideOneDimVal = strideOneDim.value();

-    if (fastChangeDim == rank - 2 && elementWidth == 8) {
+    if (strideOneDimVal == rank - 2 && elementWidth == 8) {
       // TODO: column major layout w/ fp8 has performance regression
       return;
     }

-    if (fastChangeDim >= (rank - 2)) {
+    if (strideOneDimVal >= (rank - 2)) {
       // HW 2D block read instruction only supports contiguous access.
-      Value fastChangeStride = strides[fastChangeDim];
-      LLVM_DEBUG({
-        DBGS() << "fastChangeStride: ";
-        fastChangeStride.print(llvm::dbgs());
-        llvm::dbgs() << "\n";
-      });
+      Value fastChangeStride = strides[strideOneDimVal];
       if (!tt::intel::isConstant(fastChangeStride, 1))
         return;

       // Across Intel platforms, the strictest pitch restriction is to be a
       // multiple of OWord(128 bits).
       Value pitch =
-          strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
+          strides[(strideOneDimVal == rank - 1) ? rank - 2 : rank - 1];
       LDBG("Pitch: " << pitch);
       if (!ttgi::isDivisible(pitch, 128 / elementWidth))
         return;

-      const bool isRowMajor = fastChangeDim == rank - 1;
+      const bool isRowMajor = (strideOneDimVal == rank - 1);
       std::optional<ttg::DotOperandEncodingAttr> dotLayout =
           getDotLayout(loadOp);
       if (dotLayout) {
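For context on the pitch restriction noted in this hunk: a pitch is a multiple of an OWord (128 bits) exactly when, measured in elements, it is divisible by `128 / elementWidth`. Below is a minimal standalone sketch of that arithmetic, using hypothetical constant inputs and a hypothetical helper name; the pass itself runs the same test through `ttgi::isDivisible` on MLIR `Value`s that need not be compile-time constants.

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical scalar model of the OWord pitch check: the pitch, expressed
// in elements, must be a multiple of 128 / elementWidth so that the pitch
// in bits is a multiple of an OWord (128 bits).
static bool pitchIsOWordMultiple(unsigned pitchInElements,
                                 unsigned elementWidthInBits) {
  assert(elementWidthInBits != 0 && 128 % elementWidthInBits == 0 &&
         "expected an element width dividing 128");
  return pitchInElements % (128 / elementWidthInBits) == 0;
}

int main() {
  std::printf("%d\n", pitchIsOWordMultiple(4, 32));  // 1: fp32 needs multiples of 4
  std::printf("%d\n", pitchIsOWordMultiple(12, 16)); // 0: fp16 needs multiples of 8
  std::printf("%d\n", pitchIsOWordMultiple(32, 8));  // 1: i8 needs multiples of 16
  return 0;
}
```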
@@ -292,6 +263,81 @@ struct TritonIntelGPUMaterializeBlockPointerPass

     return std::nullopt;
   }
+
+  std::optional<unsigned>
+  getStrideOneDim(tt::MakeTensorPtrOp makeTensorPtrOp) const {
+    assert(makeTensorPtrOp && "Expected a make tensor ptr op.");
+    Operation::operand_range strides = makeTensorPtrOp.getStrides();
+    std::optional<unsigned> strideOneDim{std::nullopt};
+    for (auto [idx, stride] : llvm::enumerate(strides)) {
+      if (!tt::intel::isConstant(stride, 1))
+        continue;
+      strideOneDim = idx;
+      break;
+    }
+    return strideOneDim;
+  }
+
+  bool satisfies2DBlockReadAlignment(tt::LoadOp loadOp) const {
+    Value ptr = loadOp.getPtr();
+    assert(tt::isTensorPointerType(ptr.getType()) &&
+           "Expected a ptr to a tensor of ptrs.");
+    assert(isa<RankedTensorType>(loadOp.getResult().getType()) &&
+           "Expected 'loadOp' to load a ranked tensor value.");
+
+    // Find the make tensor ptr operation that created the base ptr for the
+    // load operation.
+    tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr);
+    assert(makeTensorPtrOp && "Expected a make tensor ptr op.");
+
+    Operation::operand_range shape = makeTensorPtrOp.getShape();
+    if (shape.size() == 1)
+      return false;
+
+    // Ensure the base ptr is 4-byte aligned.
+    // Note: the HW requires the address to be 64-byte aligned; however, we
+    // compensate by imposing restrictions on the offsetX and baseWidth.
+    TypedValue<tt::PointerType> base = makeTensorPtrOp.getBase();
+    if (!ttgi::isDivisible(base, 4)) {
+      LDBG("Found non 4-byte aligned base: " << base);
+      return false;
+    }
+
+    std::optional<unsigned> strideOneDim = getStrideOneDim(makeTensorPtrOp);
+    if (!strideOneDim) {
+      LDBG("Could not find stride one dimension in: " << makeTensorPtrOp);
+      return false;
+    }
+
+    auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
+    auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+    unsigned elementWidth = tensorType.getElementTypeBitWidth();
+    unsigned strideOneDimVal = strideOneDim.value();
+    LDBG("strideOneDim: " << strideOneDimVal);
+
+    // Analyze the shape of the stride one dimension to ensure it satisfies
+    // HW constraints.
+    Value baseWidth = shape[strideOneDimVal];
+    unsigned divisor = std::ceil(32 / elementWidth);
+    if (!ttgi::isDivisible(baseWidth, divisor)) {
+      LDBG("baseWidth does not satisfy HW constraint: " << baseWidth);
+      return false;
+    }
+    LDBG("baseWidth: " << baseWidth);
+
+    // Analyze the initial offset corresponding to the stride one dimension
+    // to ensure it satisfies HW constraints.
+    Value offset = makeTensorPtrOp.getOffsets()[strideOneDimVal];
+    if (!ttgi::isDivisible(offset, divisor)) {
+      LDBG("offset does not satisfy HW constraints: " << offset);
+      return false;
+    }
+    LDBG("offset: " << offset);
+
+    // TODO: analyze tt.advance (issue #3762).
+
+    return true;
+  }
 };

 } // anonymous namespace
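To make the divisibility arithmetic in the new `satisfies2DBlockReadAlignment` helper concrete: for the 8/16/32-bit element widths involved, `std::ceil(32 / elementWidth)` reduces to the integer quotient `32 / elementWidth` (the division is integral, so the `ceil` is a no-op), i.e. the number of elements that occupy 4 bytes. Requiring `baseWidth` and the stride-one offset to be multiples of that count keeps both 4-byte aligned, which is what allows a base that is only 4-byte (rather than 64-byte) aligned to be compensated for. A minimal sketch under those assumptions, with plain integers standing in for the MLIR `Value` operands and hypothetical names throughout:

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical scalar model of the baseWidth/offset checks: divisor is the
// number of elements per 4 bytes, so multiples of it are 4-byte aligned.
static bool alignmentChecksPass(unsigned elementWidthInBits,
                                unsigned baseWidthInElements,
                                unsigned offsetXInElements) {
  assert((elementWidthInBits == 8 || elementWidthInBits == 16 ||
          elementWidthInBits == 32) &&
         "sketch assumes the element widths the pass handles");
  unsigned divisor = 32 / elementWidthInBits; // 4 for i8, 2 for f16, 1 for f32
  return baseWidthInElements % divisor == 0 &&
         offsetXInElements % divisor == 0;
}

int main() {
  std::printf("%d\n", alignmentChecksPass(16, 64, 2)); // 1: both multiples of 2
  std::printf("%d\n", alignmentChecksPass(8, 64, 3));  // 0: offset not 4-byte aligned
  return 0;
}
```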