@@ -33,26 +33,36 @@ namespace {
3333// / - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
3434// / - its pitch is not divisible by Qword bitwidth
3535// / - it is not contiguous in memory
36- bool shouldRemove (tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp) {
36+ bool shouldRemove (tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp,
37+ const bool isUsedByBlockLoadOrStoreOp) {
3738 LDBG (" Considering removal of: " << op);
3839 if (!op->getParentOfType <ModuleOp>()->hasAttr (
3940 ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName ())) {
4041 LDBG (" Marked for removal: 2D block operation not supported" );
4142 return true ;
4243 }
4344
45+ if (isUsedByBlockLoadOrStoreOp) { // block-IO users keep the tensor pointer regardless of its layout
46+ LDBG (" Used by block load/store, skipping removal" );
47+ return false ;
48+ }
49+
4450 auto ptrType = cast<tt::PointerType>(op.getType ());
4551 LDBG (" Op ptr type: " << ptrType);
4652 auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
4753 LDBG (" Op tensor type: " << tensorType);
54+ LDBG (" Used by load or store op? " << isUsedByLoadOrStoreOp);
4855
49- if (! ttgi::hasDotDpasEncoding (tensorType) &&
50- ! (isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding (tensorType))) {
51- LDBG (" Marked for removal: tensor doesn't have DPAS layout and is not used "
52- " by load or store op with DPAS layout " );
53- return true ;
56+ if (ttgi::hasDotDpasEncoding (tensorType) || // keep iff A || B: exact negation of the old removal test (!A && !B)
57+ (isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding (tensorType))) {
58+ LDBG (" Tensor has Dot-DPAS layout, or has DPAS layout and is used by load/store op, "
59+ " skipping removal " );
60+ return false ;
5461 }
55- return false ;
62+
63+ LDBG (" Marked for removal: tensor doesn't have DPAS layout and is not used "
64+ " by load or store op with DPAS layout" );
65+ return true ;
5666}
5767
5868// / The `RewritedInfo` struct is used to store information about a rewritten
@@ -683,37 +693,101 @@ class TritonIntelGPURewriteTensorPointerPass
683693 };
684694
685695 auto markTensorPointerForRemoval =
686- [this ](Value val, bool isUsedByLoadOrStoreOp = false ) {
696+ [this ](Value val, bool isUsedByLoadOrStoreOp = false , // inserts val into valueToRemove when shouldRemove() accepts its MakeTensorPtrOp
697+ bool isUsedByBlockLoadOrStoreOp = false ) { // new flag; forwarded unchanged to shouldRemove()
687698 if (tt::isTensorPointerType (val.getType ())) {
688699 tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp (val);
689- if (shouldRemove (makeTensorPtrOp, isUsedByLoadOrStoreOp))
700+ if (shouldRemove (makeTensorPtrOp, isUsedByLoadOrStoreOp,
701+ isUsedByBlockLoadOrStoreOp)) { // NOTE(review): with the `#if 1` path in the walk below, the visible walk no longer calls this lambda - confirm remaining callers
690702 valueToRemove.insert (val);
703+ }
691704 }
692705 };
693706
694707 mod.walk ([&](Operation *op) {
695708 if (isa<tt::MakeTensorPtrOp>(op)) {
709+ DenseSet<Operation *> workingSet; // ops still to visit, seeded from the users of this MakeTensorPtrOp's result
710+
711+ auto makeTensorPtrOp = cast<tt::MakeTensorPtrOp>(op); // isa<> was checked above, so cast<> (asserting) replaces the unchecked dyn_cast<>
712+ LDBG (" Considering: " << *op);
696713 Value result = op->getResult (0 );
697- markTensorPointerForRemoval (result, usedByLoadOrStoreOp (result));
714+ for (auto user : result.getUsers ()) {
715+ workingSet.insert (user); // TODO: safe? need to check ptr?
716+ }
717+ while (!workingSet.empty ()) {
718+ for (auto v : workingSet) {
719+ LDBG (" Working set val: " << *v);
720+ }
721+ auto crtOpItr = workingSet.begin ();
722+ auto crtOp = *crtOpItr;
723+ LDBG (" Processing op: " << *crtOp);
724+ if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
725+ LDBG (" is load store, should remove?" );
726+ if (shouldRemove (
727+ makeTensorPtrOp, /* isUsedByLoadOrStoreOp=*/ true ,
728+ /* isBlockLoadOrStore=*/
729+ crtOp->hasAttr (
730+ ttgi::TritonIntelGPUDialect::getBlockIOAttrName ()))) {
731+ LDBG (" Removing: " << result);
732+ valueToRemove.insert (result);
733+ }
734+ } else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
735+ for (auto [arg, blockArg] :
736+ llvm::zip (forOp.getInitArgs (),
737+ forOp.getBody ()->getArguments ().drop_front (
738+ forOp.getNumInductionVars ()))) {
739+ if (arg == makeTensorPtrOp) { // only init args that are this op's own result are followed into the loop body
740+ // add users of block arg
741+ for (auto user : blockArg.getUsers ()) {
742+ workingSet.insert (user);
743+ }
744+ }
745+ }
746+ #if 0 // NOTE(review): disabled branch tests `op` (the walk's op), not `crtOp`; revisit before enabling
747+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
748+ for (auto operand : yieldOp.getOperands()) {
749+ workingSet.insert(operand->getResult(0));
750+ }
751+ #endif
752+ } else if (crtOp->getNumResults () > 0 ) {
753+ // TODO: handle more than one result?
754+ auto crtOpResult = crtOp->getResult (0 );
755+ LDBG (" Not a load store and not a loop, adding users to working "
756+ " set." );
757+ for (auto user : crtOpResult.getUsers ()) {
758+ workingSet.insert (user);
759+ }
760+ }
761+ workingSet.erase (crtOp); // erase by key: the inserts above may rehash the DenseSet and invalidate crtOpItr
762+ }
763+ #if 1
764+ }
765+ #else
698766 } else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
699- markTensorPointerForRemoval (op->getOperand (0 ),
700- isa<tt::LoadOp, tt::StoreOp>(op));
767+ const bool isLoadStoreOp = isa<tt::LoadOp, tt::StoreOp>(op);
768+ markTensorPointerForRemoval(
769+ op->getOperand(0), isLoadStoreOp,
770+ isLoadStoreOp &&
771+ op->hasAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName()));
701772 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
702773 for (auto [arg, blockArg] :
703774 llvm::zip(forOp.getInitArgs(),
704775 forOp.getBody()->getArguments().drop_front(
705776 forOp.getNumInductionVars()))) {
777+ LDBG("arg: " << arg);
706778 if (isa<tt::MakeTensorPtrOp>(arg.getDefiningOp())) {
707779 constexpr bool check_block_io_attribute = true;
708780 markTensorPointerForRemoval(
709781 arg.getDefiningOp()->getResult(0),
782+ usedByLoadOrStoreOp(blockArg),
710783 usedByLoadOrStoreOp(blockArg, check_block_io_attribute));
711784 }
712785 }
713786 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
714787 for (auto operand : yieldOp.getOperands())
715788 markTensorPointerForRemoval(operand);
716789 }
790+ #endif
717791 });
718792
719793 LLVM_DEBUG ({
@@ -722,7 +796,7 @@ class TritonIntelGPURewriteTensorPointerPass
722796 else {
723797 DBGS () << " Values to remove: " ;
724798 for (auto val : valueToRemove)
725- DBGS () << val;
799+ DBGS () << val << " \n " ;
726800 }
727801 });
728802
@@ -746,6 +820,7 @@ class TritonIntelGPURewriteTensorPointerPass
746820 valueToRemove.clear ();
747821 while (!eraser.empty ()) {
748822 auto op = eraser.top ();
823+ LDBG (" DELETING " << *op);
749824 eraser.pop ();
750825 op->erase ();
751826 }
0 commit comments