44#include " mlir/Dialect/Arith/IR/Arith.h"
55#include " mlir/IR/Visitors.h"
66#include " triton/Analysis/Utility.h"
7+ #include " llvm/Support/Debug.h"
8+
9+ #define DEBUG_TYPE " tritonintelgpu-materialize-block-pointer"
10+ #define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
11+ #define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
712
813using namespace mlir ;
914namespace tt = mlir::triton;
@@ -33,6 +38,8 @@ struct TritonIntelGPUMaterializeBlockPointerPass
3338
3439 MLIRContext *context = &getContext ();
3540 mod.walk ([context](tt::LoadOp loadOp) {
41+ LDBG (" Considering op: " << loadOp);
42+
3643 Value ptr = loadOp.getPtr ();
3744 if (!tt::isTensorPointerType (ptr.getType ()))
3845 return ;
@@ -41,26 +48,35 @@ struct TritonIntelGPUMaterializeBlockPointerPass
4148 " Expected 'loadOp' to load a tensor value." );
4249
4350 tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp (ptr);
51+ LDBG (" Found make tensor ptr op: " << makeTensorPtrOp);
4452 auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType ());
4553 auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
4654 ArrayRef<int32_t > order = makeTensorPtrOp.getOrder ();
4755 unsigned rank = order.size ();
56+ LDBG (" Rank: " << rank);
4857 if (rank == 1 )
4958 return ;
5059
5160 unsigned fastChangeDim = order[0 ];
61+ LDBG (" Fast change dim: " << fastChangeDim);
5262 if (fastChangeDim >= (rank - 2 )) {
5363 Operation::operand_range strides = makeTensorPtrOp.getStrides ();
5464
5565 // HW 2D block read instruction only supports contiguous access.
5666 Value fastChangeStride = strides[fastChangeDim];
67+ LLVM_DEBUG ({
68+ DBGS () << " fastChangeStride: " ;
69+ fastChangeStride.print (llvm::dbgs ());
70+ llvm::dbgs () << " \n " ;
71+ });
5772 if (!mlir::triton::gpu::intel::isConstant (fastChangeStride, 1 ))
5873 return ;
5974
6075 // Across Intel platforms, the strictest pitch restriction is to be a
6176 // multiple of OWord(128 bits).
6277 Value pitch =
6378 strides[(fastChangeDim == rank - 1 ) ? rank - 2 : rank - 1 ];
79+ LDBG (" Pitch: " << pitch);
6480 if (!ttgi::isDivisible (pitch, 64 / tensorType.getElementTypeBitWidth ()))
6581 return ;
6682
0 commit comments