@@ -49,16 +49,16 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByStoreOp,
4949 LDBG (" Used by store op? " << isUsedByStoreOp);
5050 LDBG (" Used by block load op? " << isUsedByBlockLoadOp);
5151
52- LDBG (" hasDotDpasEncoding: " << ttgi::hasDotDpasEncoding (tensorType));
5352 LDBG (" hasDpasEncoding: " << ttgi::hasDpasEncoding (tensorType));
54- if (/* ttgi::hasDotDpasEncoding(tensorType) ||*/ isUsedByBlockLoadOp || (isUsedByStoreOp && ttgi::hasDpasEncoding (tensorType))) {
53+ if (isUsedByBlockLoadOp ||
54+ (isUsedByStoreOp && ttgi::hasDpasEncoding (tensorType))) {
5555 LDBG (" Tensor has DPAS layout or is used by load/store op with DPAS layout, "
5656 " skipping removal" );
5757 return false ;
5858 }
5959
60- LDBG (" Marked for removal: tensor doesn't have DPAS layout and is not used "
61- " by load or store op with DPAS layout" );
60+ LDBG (" Marked for removal: make tensor ptr op is not used by block load op or "
61+ " by store op with DPAS layout" );
6262 return true ;
6363}
6464
@@ -689,8 +689,6 @@ class TritonIntelGPURewriteTensorPointerPass
689689 });
690690 };
691691
692-
693- // TODO: this is working, but materialize block pointer needs to
694692 DenseSet<Operation *> tensorPointersToRemove;
695693 mod.walk ([&](Operation *op) {
696694 if (isa<tt::MakeTensorPtrOp>(op)) {
@@ -711,11 +709,12 @@ class TritonIntelGPURewriteTensorPointerPass
711709 LDBG (" Processing op: " << *crtOp);
712710 if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
713711 LDBG (" is load store, should remove?" );
714- if (shouldRemove (
715- makeTensorPtrOp, /* isUsedByStoreOp=*/ isa<tt::StoreOp>(crtOp),
716- /* isBlockLoad=*/
717- isa<tt::LoadOp>(crtOp) && crtOp->hasAttr (
718- ttgi::TritonIntelGPUDialect::getBlockIOAttrName ()))) {
712+ if (shouldRemove (makeTensorPtrOp,
713+ /* isUsedByStoreOp=*/ isa<tt::StoreOp>(crtOp),
714+ /* isBlockLoad=*/
715+ isa<tt::LoadOp>(crtOp) &&
716+ crtOp->hasAttr (ttgi::TritonIntelGPUDialect::
717+ getBlockIOAttrName ()))) {
719718 LDBG (" Removing: " << result);
720719 tensorPointersToRemove.insert (makeTensorPtrOp);
721720 }
@@ -779,8 +778,9 @@ class TritonIntelGPURewriteTensorPointerPass
779778 });
780779
781780 auto markTensorPointerForRemoval =
782- [this , &tensorPointersToRemove](Value val, bool isUsedByLoadOrStoreOp = false ,
783- bool isUsedByBlockLoadOrStoreOp = false ) {
781+ [this ,
782+ &tensorPointersToRemove](Value val, bool isUsedByLoadOrStoreOp = false ,
783+ bool isUsedByBlockLoadOrStoreOp = false ) {
784784 if (tt::isTensorPointerType (val.getType ())) {
785785 tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp (val);
786786 if (tensorPointersToRemove.count (makeTensorPtrOp)) {
0 commit comments