@@ -1,6 +1,5 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Dominance.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -38,32 +37,6 @@ namespace {
 // UTILS
 /////////////////////////////
 
-class OpBuilderForStage : public ImplicitLocOpBuilder,
-                          public OpBuilder::Listener {
-public:
-  explicit OpBuilderForStage(Location loc, Operation *op,
-                             CoarseSchedule &schedule)
-      : ImplicitLocOpBuilder(loc, op, this), schedule(schedule) {
-    if (auto it = schedule.find(op); it != schedule.end())
-      std::tie(stage, cluster) = it->second;
-  }
-
-  void setStageCluster(std::pair<int, CoarseSchedule::Cluster> stageCluster) {
-    stage = stageCluster.first;
-    cluster = stageCluster.second;
-  }
-
-  void notifyOperationInserted(Operation *op, InsertPoint previous) {
-    if (stage && cluster)
-      schedule.insert(op, *stage, *cluster);
-  }
-
-private:
-  std::optional<int> stage;
-  std::optional<CoarseSchedule::Cluster> cluster;
-  CoarseSchedule &schedule;
-};
-
 int getSelfLatencyFromAttr(Operation *op) {
   auto module = op->getParentOfType<ModuleOp>();
   auto helper = TritonDialect::getLoaded(module)->getSelfLatencyAttrHelper();
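Note: the removed `OpBuilderForStage` is what made the stage bookkeeping implicit. `OpBuilder` invokes `notifyOperationInserted` on its attached `Listener` for every operation it creates, so any op built through `OpBuilderForStage` was tagged with the current stage/cluster in the `CoarseSchedule` without the call sites doing it by hand. A minimal sketch of that listener pattern, for reference (illustration only; `tagWithStage` is a hypothetical stand-in for the schedule insertion):

```cpp
#include "mlir/IR/Builders.h"

using namespace mlir;

// Fires once for every op created through a builder this listener is
// attached to (e.g. OpBuilder builder(ctx, &listener)); the removed
// class used the same hook to call schedule.insert(op, stage, cluster).
struct TagInsertedOps : public OpBuilder::Listener {
  void (*tagWithStage)(Operation *op) = nullptr; // hypothetical callback

  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    if (tagWithStage)
      tagWithStage(op); // `previous` (the old insert point) is unused here
  }
};
```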
@@ -207,17 +180,6 @@ int getDefUseStageDiff(Operation *op, scf::ForOp forOp,
   return useStage.value() - defStage;
 }
 
-Value createIncrementModulo(OpBuilder &builder, Location loc, Value counter,
-                            Value modulus, Value zero, Value one,
-                            Value *outWrapCond = nullptr) {
-  Value addOne = builder.create<arith::AddIOp>(loc, counter, one);
-  Value outOfRangeCond = builder.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::sge, addOne, modulus);
-  if (outWrapCond)
-    *outWrapCond = outOfRangeCond;
-  return builder.create<arith::SelectOp>(loc, outOfRangeCond, zero, addOne);
-}
-
 void replaceAllUsesDominatedBy(Operation *domOp, Value newValue, Value oldValue,
                                DominanceInfo &domInfo) {
   if (newValue == oldValue)
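Note: `createIncrementModulo` computed a wrapping increment, `(counter + 1) mod modulus`, with a compare-and-select instead of an actual remainder op; this is correct as long as `counter` is already in `[0, modulus)`, which holds for the buffer indices it was used on. A scalar C++ analogue of the IR it emitted (illustration only):

```cpp
#include <cassert>

// Mirrors the arith.addi / arith.cmpi sge / arith.select sequence that
// the deleted createIncrementModulo built.
int incrementModulo(int counter, int modulus, bool *outWrapCond = nullptr) {
  assert(counter >= 0 && counter < modulus && "counter must be in range");
  int addOne = counter + 1;         // arith.addi
  bool wrapped = addOne >= modulus; // arith.cmpi sge
  if (outWrapCond)
    *outWrapCond = wrapped;         // optional wrap flag for callers
  return wrapped ? 0 : addOne;      // arith.select(wrapped, zero, addOne)
}
```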
@@ -644,132 +606,6 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
   return forOp;
 }
 
-/////////////////////////////
-// LOWER TMA DESCRIPTORS
-/////////////////////////////
-
-void allocTMABuffers(scf::ForOp forOp,
-                     llvm::MapVector<Operation *, Value> &tmaBufferMapping,
-                     int maxStage) {
-  IRRewriter rewriter(forOp);
-
-  // Create a multi-buffered allocation for each MakeTensorDescOp call in the
-  // loop
-  forOp.walk([&](tt::MakeTensorDescOp op) {
-    // TODO peter: walk to loop yield to find the init value if this is a
-    // loop-carried value. That would save us from allocating another buffer
-    // just for the init value
-    auto loc = op.getLoc();
-    Value alloc = rewriter.create<triton::gpu::GlobalScratchAllocOp>(
-        loc, triton::getPointerType(rewriter.getI8Type()),
-        maxStage * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN);
-    tmaBufferMapping[op.getOperation()] = alloc;
-  });
-}
-
-Value subviewTMADescriptor(OpBuilder &builder, Location loc, Value alloc,
-                           Value counter) {
-  Value tmaSizeVal =
-      builder.create<arith::ConstantIntOp>(loc, ttng::TMA_SIZE_BYTES, 32);
-  Value offset = builder.create<arith::MulIOp>(loc, tmaSizeVal, counter);
-  return builder.create<triton::AddPtrOp>(loc, alloc.getType(), alloc, offset);
-}
-
-LogicalResult rewriteTMABufferUpdates(
-    scf::ForOp forOp,
-    const llvm::MapVector<Operation *, Value> &tmaBufferMapping,
-    ArrayRef<BlockArgument> tmaCounters, int numBuffers, Value one, Value zero,
-    CoarseSchedule &schedule) {
-  assert(tmaBufferMapping.size() == tmaCounters.size());
-
-  Value numBuffersVal = mlir::OpBuilder(forOp).create<arith::ConstantIntOp>(
-      forOp.getLoc(), numBuffers, 32);
-
-  for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) {
-    auto &[op, alloc] = pair;
-
-    // Rewrite MakeTensorDescOp as writing a TMA descriptor
-    auto makeDescOp = cast<tt::MakeTensorDescOp>(op);
-
-    OpBuilderForStage builder(makeDescOp.getLoc(), makeDescOp, schedule);
-
-    BlockArgument counter = tmaCounters[iOp];
-    Value nextBuf =
-        subviewTMADescriptor(builder, builder.getLoc(), alloc, counter);
-    if (failed(ttng::createTMADesc(nextBuf, makeDescOp, builder))) {
-      return failure();
-    }
-    builder.create<ttng::TensormapFenceproxyAcquireOp>(nextBuf);
-    Value nextDesc = builder.create<ttng::ReinterpretTensorDescOp>(
-        makeDescOp.getType(), nextBuf);
-
-    makeDescOp.getResult().replaceAllUsesWith(nextDesc);
-
-    // Increment the buffer index counter
-    Value nextCounter = createIncrementModulo(
-        builder, builder.getLoc(), counter, numBuffersVal, zero, one);
-
-    // If we are in a (potentially nested) if region, propagate the counter
-    // up to the main for op body scope
-    IRRewriter rewriter(forOp);
-    nextCounter =
-        sinkValueRedefinition(rewriter, counter, nextCounter, op->getBlock());
-
-    // Finally, rewrite the loop level yield
-    auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-    forYield.setOperand(counter.getArgNumber() - 1, nextCounter);
-  }
-  return success();
-}
-
-scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule) {
-  llvm::MapVector<Operation *, Value> tmaBufferMapping;
-  int maxStage = schedule.getNumStages() - 1;
-  for (auto &op : forOp.getBody()->without_terminator()) {
-    if (auto wgMmaOp = dyn_cast<ttng::WarpGroupDotOp>(&op)) {
-      // Hopper only: Add one more buffer slice if there is a WarpGroupDotOp,
-      // as if it will be pipelined, we will effectively make the pipeline
-      // one stage longer.
-      maxStage += 1;
-      break;
-    }
-  }
-  allocTMABuffers(forOp, tmaBufferMapping, maxStage);
-  if (tmaBufferMapping.empty())
-    return forOp;
-
-  IRRewriter builder(forOp);
-  Location loc = forOp.getLoc();
-  Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
-  Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
-  SmallVector<Value> newOperands;
-  unsigned newOperandIndex = forOp.getBody()->getNumArguments();
-  // Create one counter per TMA buffer. This allows the descriptors to be
-  // updated independently without needing to write duplicates of existing
-  // TMA descriptors.
-  unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size();
-  for (int i = 0; i < tmaBufferMapping.size(); ++i) {
-    newOperands.push_back(zero);
-  }
-
-  forOp = addIterArgsToLoop(builder, forOp, newOperands);
-
-  auto tmaCounters = ArrayRef<BlockArgument>(forOp.getBody()->getArguments())
-                         .slice(tmaCounterArgsStartIdx);
-
-  // Update yield op with temporary yield values
-  auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-  for (unsigned i = 0; i < newOperands.size(); ++i) {
-    forYield.getResultsMutable().append(newOperands[i]);
-  }
-
-  if (failed(rewriteTMABufferUpdates(forOp, tmaBufferMapping, tmaCounters,
-                                     maxStage, one, zero, schedule))) {
-    llvm_unreachable("Failed to rewrite TMA ops");
-  }
-  return forOp;
-}
-
 /////////////////////////////
 // LOWER MMA
 /////////////////////////////
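Note: the removed TMA-descriptor path multi-buffered in-loop descriptor updates: one flat scratch allocation of `maxStage * TMA_SIZE_BYTES` bytes per `MakeTensorDescOp`, plus one loop-carried counter per descriptor that selects the active slot and advances with the wrapping increment shown earlier. The slot selection in `subviewTMADescriptor` is plain strided pointer arithmetic; a scalar sketch (illustration only; the helper name is hypothetical):

```cpp
#include <cstddef>
#include <cstdint>

// Each stage writes its descriptor into a private slot so an in-flight
// TMA operation never reads a descriptor that is being rewritten.
std::byte *selectDescriptorSlot(std::byte *alloc, int32_t counter,
                                int32_t tmaSizeBytes) {
  // offset = counter * TMA_SIZE_BYTES, matching the arith.muli +
  // tt.addptr pair the removed subviewTMADescriptor emitted.
  return alloc + static_cast<std::ptrdiff_t>(counter) * tmaSizeBytes;
}
```

Each loop iteration then yields the advanced counter through the `scf.for` iter_args, so the next iteration targets the following slot.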