11#include " mlir/Dialect/Arith/IR/Arith.h"
22#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
33#include " mlir/Dialect/SCF/IR/SCF.h"
4+ #include " mlir/IR/Attributes.h"
5+ #include " mlir/IR/Block.h"
46#include " mlir/IR/BuiltinAttributes.h"
57#include " mlir/IR/BuiltinOps.h"
68#include " mlir/IR/BuiltinTypes.h"
79#include " mlir/IR/IRMapping.h"
810#include " mlir/IR/Matchers.h"
11+ #include " mlir/IR/OperationSupport.h"
912#include " mlir/IR/PatternMatch.h"
1013#include " mlir/IR/TypeUtilities.h"
1114#include " mlir/IR/Value.h"
1518#include " triton/Dialect/Triton/IR/Types.h"
1619#include " triton/Dialect/TritonGPU/Transforms/Utility.h"
1720#include " llvm/ADT/STLExtras.h"
21+ #include " llvm/ADT/SmallVector.h"
22+ #include " llvm/ADT/StringRef.h"
1823#include " llvm/ADT/TypeSwitch.h"
1924#include " llvm/Support/Casting.h"
2025#include " llvm/Support/Debug.h"
@@ -85,6 +90,8 @@ class PointerCanonicalizer {
8590 Value offset;
8691 // Flag to express if we can narrow the uses of the offset down to 32 bits
8792 bool canNarrow = false ;
93+ // Collection of attributes that need to be applied to the pointer
94+ SmallVector<NamedAttribute> attributes;
8895
8996 // Utility copy functions
9097 FatPtr copy (Value newBasePtr, Value newOffset) {
@@ -96,6 +103,11 @@ class PointerCanonicalizer {
96103 FatPtr copyWithOffset (Value newBase) {
97104 return FatPtr{newBase, offset, canNarrow};
98105 }
106+ // Attribute functions
107+ void setAttr (NamedAttribute attr) { attributes.push_back (attr); }
108+ void setAttrs (ArrayRef<NamedAttribute> attrs) {
109+ llvm::append_range (attributes, attrs);
110+ }
99111 };
100112
101113 // Rewrite any operation that needs a pointer
@@ -104,8 +116,15 @@ class PointerCanonicalizer {
104116 // Start from an argument of a function and propagate its fat pointers
105117 LogicalResult rewritePointer (Value argPtr);
106118
119+ // Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is
120+ // obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape
121+ // and adding the offset to it.
107122 Value createTensorPointer (FatPtr fatPtr, Location loc);
108123
124+ // Push the attributes of the given operation `op` to the fat pointer
125+ // corresponding to `val`
126+ void collectFatPointerAttributes (Operation *op, Value val);
127+
109128 // Rewrite a given function, canonicalizing the different pointer arguments of
110129 // the region
111130 LogicalResult rewriteFunction (triton::FuncOp funcOp);
@@ -269,6 +288,46 @@ Value createTensorZero(IRRewriter &rw, Location loc, RankedTensorType type) {
269288
270289} // namespace
271290
291+ void PointerCanonicalizer::collectFatPointerAttributes (Operation *op,
292+ Value val) {
293+ auto addBlockArgumentAttr = [&](BlockArgument arg) {
294+ // If the value is a block parameter, the operation can specify
295+ // an attribute for the given parameter by using `tt.property_argi`
296+ // where `argi` refers to the arg number of the given parameter.
297+ // So we need to iterate through the property, find the right one
298+ // and push the property onto the pointers attributes.
299+ llvm::SmallString<8 > scratchStr;
300+ for (NamedAttribute namedAttr : op->getAttrs ()) {
301+ scratchStr.clear ();
302+ llvm::raw_svector_ostream sstream (scratchStr);
303+ sstream << " _arg" << arg.getArgNumber ();
304+ StringRef attrName = namedAttr.getName ().getValue ();
305+ if (attrName.ends_with (scratchStr)) {
306+ StringRef newAttrName = attrName.drop_back (scratchStr.size ());
307+ namedAttr.setName (rewriter.getStringAttr (newAttrName));
308+ pointers[val].setAttr (namedAttr);
309+ // Propagate the argument to the offset if it is also a block argument
310+ if (auto offsetArg = dyn_cast<BlockArgument>(pointers[val].offset )) {
311+ scratchStr.clear ();
312+ sstream << newAttrName << " _arg" << offsetArg.getArgNumber ();
313+ op->setAttr (scratchStr, namedAttr.getValue ());
314+ }
315+ }
316+ }
317+ };
318+
319+ // If it is the i-th block argument, then look if the operation defined some
320+ // _argi attribute and add it to the fat pointer attributes
321+ if (auto arg = dyn_cast<BlockArgument>(val)) {
322+ addBlockArgumentAttr (arg);
323+ return ;
324+ }
325+
326+ // Otherwise add the attributes of the operation to the fat pointer
327+ for (NamedAttribute attr : op->getAttrs ())
328+ pointers[val].setAttr (attr);
329+ }
330+
272331// Offset extraction logic for an addition op:
273332// decompose(A+B) = {U(A)+U(B), NU(A)+NU(B)}
274333std::pair<Value, Value>
@@ -372,9 +431,6 @@ PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr,
372431 return offsets;
373432}
374433
375- // Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is
376- // obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape
377- // and adding the offset to it.
378434Value PointerCanonicalizer::createTensorPointer (FatPtr fatPtr, Location loc) {
379435 Value basePtr = fatPtr.basePtr ;
380436 Value offset = fatPtr.offset ;
@@ -390,9 +446,12 @@ Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) {
390446 Value tensorPtr =
391447 rewriter.create <triton::SplatOp>(loc, tensorPtrType, basePtr);
392448
393- tensorPtr =
449+ auto addPtrOp =
394450 rewriter.create <triton::AddPtrOp>(loc, tensorPtrType, tensorPtr, offset);
395- return tensorPtr;
451+
452+ addPtrOp->setAttrs (fatPtr.attributes );
453+
454+ return addPtrOp.getResult ();
396455}
397456
398457// Rewrite a memory operation
@@ -477,6 +536,9 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp,
477536 newPtr = rewriter.create <triton::AddPtrOp>(curLoc, newPtr.getType (), newPtr,
478537 scalarConst);
479538 pointers[nextPtr] = fatPtr.copyWithOffset (newPtr);
539+ // If we are updating the tensor pointer with a uniform value, we can
540+ // propagate the attributes of the tensor pointer to the fat pointer.
541+ pointers[nextPtr].setAttrs (fatPtr.attributes );
480542 opToDelete.insert (addPtrOp);
481543 return success ();
482544 }
@@ -496,6 +558,7 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp,
496558 Value fatPtrOffset = fatPtr.offset ;
497559 bool canNarrow = fatPtr.canNarrow ;
498560 Value newOffset = fatPtrOffset;
561+ bool propagateAtrs = true ;
499562 if (!isZeroConst (nonUniformOffset)) {
500563 Type addPtrOffsetType = getElementTypeOrSelf (nonUniformOffset);
501564 canNarrow = canNarrow && canNarrowOffset (fatPtrOffset, nonUniformOffset);
@@ -507,9 +570,15 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp,
507570
508571 newOffset =
509572 rewriter.create <arith::AddIOp>(curLoc, nonUniformOffset, fatPtrOffset);
573+ propagateAtrs = false ;
510574 }
511575 opToDelete.insert (addPtrOp);
512576 pointers[nextPtr] = FatPtr{newPtr, newOffset, canNarrow};
577+
578+ // If we are updating the tensor pointer with a uniform value, we can
579+ // propagate the attributes of the tensor pointer to the fat pointer.
580+ if (propagateAtrs)
581+ pointers[nextPtr].setAttrs (fatPtr.attributes );
513582 return success ();
514583}
515584
@@ -537,9 +606,12 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp,
537606 // This is making sure we visit the uses within the forOp region
538607 Value arg = newForOp.getTiedLoopRegionIterArg (forOperand);
539608 size_t numIterArgs = newForOp.getNumRegionIterArgs ();
540- pointers[arg] =
541- FatPtr{newForOp.getRegionIterArg (numIterArgs - 2 ),
542- newForOp.getRegionIterArg (numIterArgs - 1 ), fatPtr.canNarrow };
609+ pointers[arg] = fatPtr.copy (newForOp.getRegionIterArg (numIterArgs - 2 ),
610+ newForOp.getRegionIterArg (numIterArgs - 1 ));
611+
612+ // Collect attributes before continuing the visit
613+ collectFatPointerAttributes (newForOp, arg);
614+
543615 for (OpOperand &use : arg.getUses ())
544616 queue.push_back (&use);
545617
@@ -548,7 +620,6 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp,
548620 size_t numResults = newForOp->getNumResults ();
549621 pointers[nextPtr] = fatPtr.copy (newForOp->getResult (numResults - 2 ),
550622 newForOp.getResult (numResults - 1 ));
551-
552623 opToDelete.insert (forOp);
553624 return success ();
554625}
@@ -864,11 +935,13 @@ LogicalResult PointerCanonicalizer::rewritePointer(Value argPtr) {
864935 res = materializeFatPointer (op, curLoc, curOperand->get ());
865936 });
866937
867- // Keep propagating the fat pointer down the IR
868- if (nextPtr)
938+ // Collect the attributes and keep propagating the fat pointer down the IR
939+ if (nextPtr) {
940+ collectFatPointerAttributes (curOp, nextPtr);
869941 for (OpOperand &use : nextPtr.getUses ())
870942 if (!opToDelete.contains (use.getOwner ()))
871943 queue.push_back (&use);
944+ }
872945 }
873946 return success ();
874947}
0 commit comments