@@ -33,7 +33,9 @@ namespace {
 /// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
 /// - its pitch is not divisible by Qword bitwidth
 /// - it is not contiguous in memory
-bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
+bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp) {
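+  // Keep a tensor pointer only when the module supports 2D block IO and the
+  // pointer is used by a load/store with a DPAS-related layout (see below).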
   LDBG("Considering removal of: " << op);
   if (!op->getParentOfType<ModuleOp>()->hasAttr(
           ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
@@ -52,55 +54,8 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
5252 " by load or store op with DPAS layout" );
5353 return true ;
5454 }
-
-  TypedValue<triton::PointerType> base = op.getBase();
-  Operation::operand_range shape = op.getShape();
-  unsigned rank = shape.size();
-  assert(rank > 1 && "Expecting tensor with rank > 1");
-  Operation::operand_range strides = op.getStrides();
-  Operation::operand_range offsets = op.getOffsets();
-  ArrayRef<int32_t> order = op.getOrder();
-  ArrayRef<int64_t> tensorShape = tensorType.getShape();
-
-  int fastChangeDim = -1;
-  for (size_t i = 0; i < strides.size(); ++i) {
-    if (ttgi::isConstant(strides[i], 1)) {
-      fastChangeDim = i;
-      break;
-    }
-  }
-
-  LDBG("fastChangeDim: " << fastChangeDim);
-  if (fastChangeDim < 0) {
-    LDBG("Marked for removal: fast changing dimension not found");
-    return true;
-  }
-
-  LDBG("Tensor type element type bit width: "
-       << tensorType.getElementTypeBitWidth());
-  if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
-    // TODO: column major layout w/ fp8 has performance regression
-    LDBG("Marked for removal: column major layout with fp8 element type");
-    return true;
-  }
-
-  // HW 2D block read instruction has restriction on pitch divisibility
-  if (fastChangeDim >= (rank - 2)) {
-    auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
-    LDBG("Pitch: " << pitch);
-    // Across Intel platforms, the strictest pitch restriction is to be a
-    // multiple of OWord(128 bits).
-    if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
-      LDBG("Marked for removal: cannot use block read/write instructions");
-      return true;
-    }
-
-    return false;
-  }
-
-  LDBG("Marked for removal: fall-trough");
-
-  return true;
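+  // All checks passed: keep the tensor pointer as-is.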
+  return false;
 }
 
 /// The `RewritedInfo` struct is used to store information about a rewritten
@@ -715,10 +670,18 @@ class TritonIntelGPURewriteTensorPointerPass
   void runOnOperation() override {
     ModuleOp mod = getOperation();
 
-    auto usedByLoadOrStoreOp = [](Value val) {
-      return llvm::any_of(val.getUsers(), [](Operation *user) {
-        return isa<tt::LoadOp, tt::StoreOp>(user);
-      });
+    // TODO: do we need this attribute?
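+    // When checkBlockIOAttribute is set, a user qualifies by carrying the
+    // block-IO attribute rather than by being a tt.load/tt.store.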
+    auto usedByLoadOrStoreOp = [](Value val,
+                                  const bool checkBlockIOAttribute = false) {
+      return llvm::any_of(
+          val.getUsers(), [checkBlockIOAttribute](Operation *user) {
+            if (checkBlockIOAttribute)
+              return user->hasAttr(
+                  ttgi::TritonIntelGPUDialect::getBlockIOAttrName());
+            return isa<tt::LoadOp, tt::StoreOp>(user);
+          });
     };
 
     auto markTensorPointerForRemoval =
@@ -738,8 +701,20 @@ class TritonIntelGPURewriteTensorPointerPass
         markTensorPointerForRemoval(op->getOperand(0),
                                     isa<tt::LoadOp, tt::StoreOp>(op));
       } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-        for (auto arg : forOp.getInitArgs())
-          markTensorPointerForRemoval(arg);
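+        // Pair each init arg with its region iter-arg so that uses of the
+        // loop-carried tensor pointer inside the loop body are inspected.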
+        for (auto [arg, blockArg] :
+             llvm::zip(forOp.getInitArgs(),
+                       forOp.getBody()->getArguments().drop_front(
+                           forOp.getNumInductionVars()))) {
+          // getDefiningOp() may be null when the init arg is a block argument.
+          if (isa_and_nonnull<tt::MakeTensorPtrOp>(arg.getDefiningOp())) {
+            constexpr bool checkBlockIOAttribute = true;
+            markTensorPointerForRemoval(
+                arg.getDefiningOp()->getResult(0),
+                usedByLoadOrStoreOp(blockArg, checkBlockIOAttribute));
+          }
+        }
       } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
         for (auto operand : yieldOp.getOperands())
           markTensorPointerForRemoval(operand);