1717#include " mlir/Transforms/DialectConversion.h"
1818#include " triton/Analysis/Utility.h"
1919#include " triton/Dialect/Triton/IR/Dialect.h"
20+ #include " triton/Dialect/Triton/IR/DiscardableAttributes.h"
2021#include " triton/Dialect/Triton/IR/Types.h"
2122#include " triton/Dialect/TritonGPU/Transforms/Utility.h"
2223#include " llvm/ADT/STLExtras.h"
@@ -561,12 +562,20 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
561562 RewriterBase::InsertionGuard guard (rewriter);
562563 rewriter.setInsertionPoint (addPtrOp);
563564
565+ // Query all discardable attributes that we want to preserve
566+ std::array<StringRef, 3 > propagateList{" tt.divisibility" , " tt.contiguity" ,
567+ " tt.constancy" };
568+ SmallVector<NamedAttribute> propagatedAttrs =
569+ tt::filterDiscardableAttrs (addPtrOp.getOperation (), propagateList);
570+
564571 // If it is a scalar pointer update, simply bump the base pointer
565572 if (llvm::isa<tt::PointerType>(addPtrOp.getPtr ().getType ())) {
566573 assert (llvm::isa<IntegerType>(origOffset.getType ()) &&
567574 " expected offset to be integer type" );
568575 auto newAddPtrOp = rewriter.create <tt::AddPtrOp>(
569576 curLoc, fatPtrBase.getType (), fatPtrBase, origOffset);
577+ newAddPtrOp->setDiscardableAttrs (propagatedAttrs);
578+
570579 rewriter.replaceOpWithMultiple (addPtrOp, {{newAddPtrOp, fatPtrOffset}});
571580 fatPtrs[{newAddPtrOp, fatPtrOffset}] =
572581 fatPtrs.at ({fatPtrBase, fatPtrOffset});
@@ -581,6 +590,8 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
581590 maybeGetOrCreateScalarConstant (rewriter, curLoc, origOffset)) {
582591 tt::AddPtrOp newAddPtrOp = rewriter.create <tt::AddPtrOp>(
583592 curLoc, fatPtrBase.getType (), fatPtrBase, *scalarConst);
593+ newAddPtrOp->setDiscardableAttrs (propagatedAttrs);
594+
584595 rewriter.replaceOpWithMultiple (addPtrOp, {{newAddPtrOp, fatPtrOffset}});
585596 // If we are updating the tensor pointer with a constant value, we can
586597 // propagate the attributes of the tensor pointer to the fat pointer.
@@ -596,6 +607,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
596607
597608 auto newAddPtrOp = rewriter.create <tt::AddPtrOp>(
598609 curLoc, fatPtrBase.getType (), fatPtrBase, uniformOffset);
610+ newAddPtrOp->setDiscardableAttrs (propagatedAttrs);
599611
600612 // Vector offset update (if any): bump the tensor offset
601613 bool canNarrow = fatPtrs.at ({fatPtrBase, fatPtrOffset}).canNarrow ;
0 commit comments