1212#include " mlir/IR/Verifier.h"
1313#include " mlir/Support/LLVM.h"
1414#include " triton/Dialect/Triton/IR/Dialect.h"
15+ #include " triton/Dialect/Triton/IR/Types.h"
1516#include " llvm/ADT/TypeSwitch.h"
1617#include " llvm/Support/Debug.h"
1718#include " llvm/Support/ErrorHandling.h"
@@ -373,6 +374,8 @@ struct PtrState {
373374
374375 Value createTTAdvanceOp (Value ptr, tt::MakeTensorPtrOp makeTPtrOp,
375376 OpBuilder &builder, Location loc) const {
377+ assert (triton::isTensorPointerType (ptr.getType ()) &&
378+ " Expecting a block ptr" );
376379 SmallVector<Value> newOffsets;
377380 for (const auto &[offset, stride] :
378381 llvm::zip (offsets, makeTPtrOp.getStrides ()))
@@ -676,44 +679,13 @@ struct TritonRaiseBlockPointer
676679 }
677680
678681 LogicalResult rewriteAddPtrOp (tt::AddPtrOp op) {
679- LLVM_DEBUG (llvm::dbgs () << " Rewriting: " << *op << " \n " );
680-
681682 OpBuilder builder (op);
682683 Location loc = op.getLoc ();
683684 Value ptr = op.getPtr ();
684685
685- auto fillOffsets = [&](Value offset, unsigned rank,
686- SmallVector<Value> &offsets) {
687- switch (rank) {
688- case 1 :
689- offsets.push_back (offset);
690- break ;
691- case 2 :
692- offsets.push_back (
693- findOrCreateConstant (loc, 0 , offsetBitwidth, builder));
694- offsets.push_back (offset);
695- break ;
696- default :
697- llvm_unreachable (" unexpected rank" );
698- }
699- };
700-
701- auto getConstantValue = [](arith::ConstantOp cstOp) {
702- TypedAttr cstVal = cstOp.getValue ();
703- APInt val;
704- if (auto attr = dyn_cast<DenseIntElementsAttr>(cstVal))
705- val = attr.getSplatValue <APInt>();
706- else if (auto attr = dyn_cast<IntegerAttr>(cstVal))
707- val = attr.getValue ();
708- else
709- assert (false && " unexpected constant type" );
710-
711- return val;
712- };
713-
714- // If the ptr has already been mapped (i.e. rewritten into a block
715- // pointer), rewrite the AddPtrOp using and AdvanceOp.
686+ // Case 1: the ptr has been already been mapped.
716687 if (Value mappedV = ptrMap.lookupOrNull (ptr)) {
688+ // Case 1a: the ptr has been mapped to a make_tensor_ptr operation.
717689 if (auto makeTPtrOp = mappedV.getDefiningOp <tt::MakeTensorPtrOp>()) {
718690 PtrState state;
719691 if (failed (visitOperand (op.getOffset (), state, loc, builder)))
@@ -726,20 +698,60 @@ struct TritonRaiseBlockPointer
726698 cleanUp.insert (op);
727699 ptrMap.map (op.getResult (), advanceOp);
728700
729- LLVM_DEBUG (llvm::dbgs ()
730- << " Rewrote:\n\t " << op << " \n to:\n\t " << advanceOp << " \n " );
701+ LLVM_DEBUG ({
702+ auto modOp =
703+ builder.getBlock ()->getParentOp ()->getParentOfType <ModuleOp>();
704+ llvm::dbgs () << " Module:\n " << modOp << " \n " ;
705+ llvm::dbgs () << " Rewrote:\n\t " << op << " \n to:\n\t " << advanceOp
706+ << " \n " ;
707+ });
708+
709+ return success ();
710+ }
711+
712+ // Case 1b: the ptr has been mapped to a tt.advance operation.
713+ if (auto advanceOp = mappedV.getDefiningOp <tt::AdvanceOp>()) {
714+ PtrState state;
715+ if (failed (visitOperand (op.getOffset (), state, loc, builder)))
716+ return failure ();
717+
718+ // Skip through a chain of tt.advance operations...
719+ Value ptr = advanceOp.getPtr ();
720+ while (auto advanceOp = ptr.getDefiningOp <tt::AdvanceOp>())
721+ ptr = advanceOp.getPtr ();
722+
723+ // ... until we find the make_tensor_ptr operation defining the block
724+ // ptr feeding the first tt.advance operation.
725+ auto makeTPtrOp = ptr.getDefiningOp <tt::MakeTensorPtrOp>();
726+ assert (makeTPtrOp && " Expected a MakeTensorPtrOp" );
727+
728+ Value newAdvanceOp = state.createTTAdvanceOp (advanceOp.getResult (),
729+ makeTPtrOp, builder, loc);
730+
731+ cleanUp.insert (op);
732+ ptrMap.map (op.getResult (), newAdvanceOp);
733+
734+ LLVM_DEBUG ({
735+ llvm::dbgs () << " Rewrote:\n\t " << op << " \n to:\n\t " << newAdvanceOp
736+ << " \n " ;
737+ auto modOp =
738+ builder.getBlock ()->getParentOp ()->getParentOfType <ModuleOp>();
739+ llvm::dbgs () << " Module:\n " << modOp << " \n " ;
740+ });
741+
731742 return success ();
732- } else {
733- llvm_unreachable (" Did not find tt::MakeTensorPtrOp" );
734743 }
744+
745+ llvm_unreachable (" Unexpected mappedV defining operation" );
735746 }
736747
748+ // Case 2: the ptr has not previously been mapped.
737749 // If the addptr operation increments a scalar pointer, give up.
738750 Value result = op.getResult ();
739751 if (!isa<RankedTensorType>(result.getType ()))
740752 return failure ();
741753
742- // Otherwise, rewrite the AddPtrOp using PtrState .
754+ // Otherwise, rewrite the AddPtrOp.
743755 PtrState state;
744756 if (failed (visitOperandAddptr (op, state, loc, builder)))
745757 return failure ();
@@ -750,16 +762,11 @@ struct TritonRaiseBlockPointer
750762 Value makePtrOp = state.createTTMakeTensorPtrOp (builder, loc);
751763 knownPtrs[makePtrOp] = std::move (state);
752764
753- ptrMap.map (result, makePtrOp);
754-
755- LLVM_DEBUG (llvm::dbgs ()
756- << " Rewrote:\n\t " << op << " \n to:\n\t " << makePtrOp << " \n " );
757-
758- // AddPtrOps that have been rewritten and no longer used in the code must
759- // be removed in the pass to avoid type matching issue.
760765 cleanUp.insert (op);
766+ ptrMap.map (result, makePtrOp);
761767
762768 LLVM_DEBUG ({
769+ llvm::dbgs () << " Rewrote:\n\t " << op << " \n to:\n\t " << makePtrOp << " \n " ;
763770 auto modOp =
764771 builder.getBlock ()->getParentOp ()->getParentOfType <ModuleOp>();
765772 llvm::dbgs () << " Module:\n " << modOp << " \n " ;
@@ -915,8 +922,8 @@ struct TritonRaiseBlockPointer
915922 }
916923
917924 // This operand must be an iter-arg of an inner-loop in a multiple-level
918- // nested loop, which means its PtrState must have already been populated
919- // during rewriteForOp of the parent loop.
925+ // nested loop, which means its PtrState must have already been
926+ // populated during rewriteForOp of the parent loop.
920927 state = knownPtrs[operand];
921928 return success ();
922929 }
0 commit comments