@@ -33,29 +33,26 @@ namespace {
3333// / - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
3434// / - its pitch is not divisible by Qword bitwidth
3535// / - it is not contiguous in memory
36- bool shouldRemove (tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp ,
37- const bool isUsedByBlockLoadOrStoreOp ) {
36+ bool shouldRemove (tt::MakeTensorPtrOp &op, const bool isUsedByStoreOp ,
37+ const bool isUsedByBlockLoadOp ) {
3838 LDBG (" Considering removal of: " << op);
3939 if (!op->getParentOfType <ModuleOp>()->hasAttr (
4040 ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName ())) {
4141 LDBG (" Marked for removal: 2D block operation not supported" );
4242 return true ;
4343 }
4444
45- if (isUsedByBlockLoadOrStoreOp) {
46- LDBG (" Used by block load/store, skipping removal" );
47- return false ;
48- }
49-
5045 auto ptrType = cast<tt::PointerType>(op.getType ());
5146 LDBG (" Op ptr type: " << ptrType);
5247 auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
5348 LDBG (" Op tensor type: " << tensorType);
54- LDBG (" Used by load or store op? " << isUsedByLoadOrStoreOp);
49+ LDBG (" Used by store op? " << isUsedByStoreOp);
50+ LDBG (" Used by block load op? " << isUsedByBlockLoadOp);
5551
56- if (ttgi::hasDotDpasEncoding (tensorType) &&
57- (isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding (tensorType))) {
58- LDBG (" Tensor with DPAS layout is used by load/store op with DPAS layout, "
52+ LDBG (" hasDotDpasEncoding: " << ttgi::hasDotDpasEncoding (tensorType));
53+ LDBG (" hasDpasEncoding: " << ttgi::hasDpasEncoding (tensorType));
54+ if (ttgi::hasDotDpasEncoding (tensorType) || isUsedByBlockLoadOp || (isUsedByStoreOp && ttgi::hasDpasEncoding (tensorType))) {
55+ LDBG (" Tensor has DPAS layout or is used by load/store op with DPAS layout, "
5956 " skipping removal" );
6057 return false ;
6158 }
@@ -692,18 +689,9 @@ class TritonIntelGPURewriteTensorPointerPass
692689 });
693690 };
694691
695- auto markTensorPointerForRemoval =
696- [this ](Value val, bool isUsedByLoadOrStoreOp = false ,
697- bool isUsedByBlockLoadOrStoreOp = false ) {
698- if (tt::isTensorPointerType (val.getType ())) {
699- tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp (val);
700- if (shouldRemove (makeTensorPtrOp, isUsedByLoadOrStoreOp,
701- isUsedByBlockLoadOrStoreOp)) {
702- valueToRemove.insert (val);
703- }
704- }
705- };
706692
693+ // TODO: this is working, but materialize block pointer needs to
694+ DenseSet<Operation *> tensorPointersToRemove;
707695 mod.walk ([&](Operation *op) {
708696 if (isa<tt::MakeTensorPtrOp>(op)) {
709697 DenseSet<Operation *> workingSet;
@@ -724,12 +712,12 @@ class TritonIntelGPURewriteTensorPointerPass
724712 if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
725713 LDBG (" is load store, should remove?" );
726714 if (shouldRemove (
727- makeTensorPtrOp, /* isUsedByLoadOrStoreOp =*/ true ,
728- /* isBlockLoadOrStore =*/
729- crtOp->hasAttr (
715+ makeTensorPtrOp, /* isUsedByStoreOp =*/ isa<tt::StoreOp>(crtOp) ,
716+ /* isBlockLoad =*/
717+ isa<tt::LoadOp>(crtOp) && crtOp->hasAttr (
730718 ttgi::TritonIntelGPUDialect::getBlockIOAttrName ()))) {
731719 LDBG (" Removing: " << result);
732- valueToRemove .insert (result );
720+ tensorPointersToRemove .insert (makeTensorPtrOp );
733721 }
734722 } else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
735723 for (auto [arg, blockArg] :
@@ -790,6 +778,33 @@ class TritonIntelGPURewriteTensorPointerPass
790778 #endif
791779 });
792780
781+ auto markTensorPointerForRemoval =
782+ [this , &tensorPointersToRemove](Value val, bool isUsedByLoadOrStoreOp = false ,
783+ bool isUsedByBlockLoadOrStoreOp = false ) {
784+ if (tt::isTensorPointerType (val.getType ())) {
785+ tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp (val);
786+ if (tensorPointersToRemove.count (makeTensorPtrOp)) {
787+ valueToRemove.insert (val);
788+ }
789+ }
790+ };
791+
792+ mod.walk ([&](Operation *op) {
793+ if (isa<tt::MakeTensorPtrOp>(op)) {
794+ Value result = op->getResult (0 );
795+ markTensorPointerForRemoval (result, usedByLoadOrStoreOp (result));
796+ } else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
797+ markTensorPointerForRemoval (op->getOperand (0 ),
798+ isa<tt::LoadOp, tt::StoreOp>(op));
799+ } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
800+ for (auto arg : forOp.getInitArgs ())
801+ markTensorPointerForRemoval (arg);
802+ } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
803+ for (auto operand : yieldOp.getOperands ())
804+ markTensorPointerForRemoval (operand);
805+ }
806+ });
807+
793808 LLVM_DEBUG ({
794809 if (valueToRemove.empty ())
795810 DBGS () << " No tensor pointer to remove" ;