@@ -674,21 +674,6 @@ class TritonIntelGPURewriteTensorPointerPass
674674 void runOnOperation () override {
675675 ModuleOp mod = getOperation ();
676676
677- // TODO: do we need this attribute?
678- auto usedByLoadOrStoreOp = [](Value val,
679- const bool check_block_io_attribute = false ) {
680- return llvm::any_of (
681- val.getUsers (), [check_block_io_attribute](Operation *user) {
682- const bool is_load_or_store = isa<tt::LoadOp, tt::StoreOp>(user);
683- if (check_block_io_attribute) {
684- return user->hasAttr (
685- ttgi::TritonIntelGPUDialect::getBlockIOAttrName ());
686- } else {
687- return is_load_or_store;
688- }
689- });
690- };
691-
692677 DenseSet<Operation *> tensorPointersToRemove;
693678 mod.walk ([&](Operation *op) {
694679 if (isa<tt::MakeTensorPtrOp>(op)) {
@@ -698,24 +683,22 @@ class TritonIntelGPURewriteTensorPointerPass
698683 LDBG (" Considering: " << *op);
699684 Value result = op->getResult (0 );
700685 for (auto user : result.getUsers ()) {
701- workingSet.insert (user); // TODO: safe? need to check ptr?
686+ workingSet.insert (user);
702687 }
703688 while (!workingSet.empty ()) {
704- for (auto v : workingSet) {
705- LDBG (" Working set val: " << *v);
706- }
707689 auto crtOpItr = workingSet.begin ();
708690 auto crtOp = *crtOpItr;
709691 LDBG (" Processing op: " << *crtOp);
710692 if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
711- LDBG (" is load store, should remove?" );
693+ LDBG (" is load store, checking to see if we should remove make "
694+ " tensor ptr op" );
712695 if (shouldRemove (makeTensorPtrOp,
713696 /* isUsedByStoreOp=*/ isa<tt::StoreOp>(crtOp),
714697 /* isBlockLoad=*/
715698 isa<tt::LoadOp>(crtOp) &&
716699 crtOp->hasAttr (ttgi::TritonIntelGPUDialect::
717700 getBlockIOAttrName ()))) {
718- LDBG (" Removing : " << result);
701+ LDBG (" Marking op for removal : " << result);
719702 tensorPointersToRemove.insert (makeTensorPtrOp);
720703 }
721704 } else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
@@ -730,14 +713,8 @@ class TritonIntelGPURewriteTensorPointerPass
730713 }
731714 }
732715 }
733- #if 0
734- } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
735- for (auto operand : yieldOp.getOperands()) {
736- workingSet.insert(operand->getResult(0));
737- }
738- #endif
739716 } else if (crtOp->getNumResults () > 0 ) {
740- // TODO: handle more than one result?
717+ // TODO: should we handle more than one result?
741718 auto crtOpResult = crtOp->getResult (0 );
742719 LDBG (" Not a load store and not a loop, adding users to working "
743720 " set." );
@@ -747,55 +724,25 @@ class TritonIntelGPURewriteTensorPointerPass
747724 }
748725 workingSet.erase (crtOpItr);
749726 }
750- #if 1
751- }
752- #else
753- } else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
754- const bool isLoadStoreOp = isa<tt::LoadOp, tt::StoreOp>(op);
755- markTensorPointerForRemoval(
756- op->getOperand(0), isLoadStoreOp,
757- isLoadStoreOp &&
758- op->hasAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName()));
759- } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
760- for (auto [arg, blockArg] :
761- llvm::zip(forOp.getInitArgs(),
762- forOp.getBody()->getArguments().drop_front(
763- forOp.getNumInductionVars()))) {
764- LDBG("arg: " << arg);
765- if (isa<tt::MakeTensorPtrOp>(arg.getDefiningOp())) {
766- constexpr bool check_block_io_attribute = true;
767- markTensorPointerForRemoval(
768- arg.getDefiningOp()->getResult(0),
769- usedByLoadOrStoreOp(blockArg),
770- usedByLoadOrStoreOp(blockArg, check_block_io_attribute));
771- }
772- }
773- } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
774- for (auto operand : yieldOp.getOperands())
775- markTensorPointerForRemoval(operand);
776727 }
777- #endif
778728 });
779729
780- auto markTensorPointerForRemoval =
781- [this ,
782- &tensorPointersToRemove](Value val, bool isUsedByLoadOrStoreOp = false ,
783- bool isUsedByBlockLoadOrStoreOp = false ) {
784- if (tt::isTensorPointerType (val.getType ())) {
785- tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp (val);
786- if (tensorPointersToRemove.count (makeTensorPtrOp)) {
787- valueToRemove.insert (val);
788- }
789- }
790- };
730+ auto markTensorPointerForRemoval = [this ,
731+ &tensorPointersToRemove](Value val) {
732+ if (tt::isTensorPointerType (val.getType ())) {
733+ tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp (val);
734+ if (tensorPointersToRemove.count (makeTensorPtrOp)) {
735+ valueToRemove.insert (val);
736+ }
737+ }
738+ };
791739
792740 mod.walk ([&](Operation *op) {
793741 if (isa<tt::MakeTensorPtrOp>(op)) {
794742 Value result = op->getResult (0 );
795- markTensorPointerForRemoval (result, usedByLoadOrStoreOp (result) );
743+ markTensorPointerForRemoval (result);
796744 } else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
797- markTensorPointerForRemoval (op->getOperand (0 ),
798- isa<tt::LoadOp, tt::StoreOp>(op));
745+ markTensorPointerForRemoval (op->getOperand (0 ));
799746 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
800747 for (auto arg : forOp.getInitArgs ())
801748 markTensorPointerForRemoval (arg);
0 commit comments