|
| 1 | +#include "CodePartitionUtility.h" |
| 2 | +#include "mlir/Analysis/SliceAnalysis.h" |
| 3 | +#include "mlir/Pass/Pass.h" |
| 4 | +#include "mlir/Pass/PassManager.h" |
| 5 | +#include "mlir/Transforms/Passes.h" |
| 6 | +#include "nvidia/hopper/include/Transforms/Passes.h" |
| 7 | +#include <list> |
| 8 | +#include <unordered_set> |
| 9 | + |
| 10 | +namespace tt = mlir::triton; |
| 11 | +namespace ttg = mlir::triton::gpu; |
| 12 | +namespace ttng = ::mlir::triton::nvidia_gpu; |
| 13 | +namespace mlir { |
| 14 | + |
| 15 | +#define DEBUG_TYPE "nvgpu-ws-utility" |
| 16 | +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 17 | +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
| 18 | + |
| 19 | +// Check to see if op is enclosed under ifOp. |
| 20 | +bool enclosing(scf::IfOp ifOp, Operation *op) { |
| 21 | + return ifOp->isProperAncestor(op); |
| 22 | +} |
| 23 | + |
| 24 | +bool enclosing(scf::ForOp forOp, Operation *op) { |
| 25 | + return forOp->isProperAncestor(op); |
| 26 | +} |
| 27 | + |
| 28 | +// Check to see if there is no outer loop that is enclosed under ifOp. |
| 29 | +bool immediateEnclosing(scf::IfOp ifOp, Operation *subOp) { |
| 30 | + auto pOp = subOp->getParentOfType<scf::ForOp>(); |
| 31 | + if (!pOp) |
| 32 | + return true; |
| 33 | + return !enclosing(ifOp, pOp.getOperation()); |
| 34 | +} |
| 35 | + |
| 36 | +// Return number of AccumCnts for the given ctrlOp. We need one for each nested |
| 37 | +// region that contains a channel. |
| 38 | +unsigned getAccumCnts(Operation *ctrlOp, |
| 39 | + const DenseSet<Operation *> ®ionsWithChannels) { |
| 40 | + unsigned cnt = 0; |
| 41 | + LDBG("getAccumCnts: " << ctrlOp); |
| 42 | + for (auto *op : regionsWithChannels) { |
| 43 | + LDBG("-- getAccumCnts: " << ctrlOp << " regionsWithChannels " << op); |
| 44 | + if (ctrlOp == op) { |
| 45 | + ++cnt; |
| 46 | + continue; |
| 47 | + } |
| 48 | + if (auto forOp = dyn_cast<scf::ForOp>(ctrlOp)) { |
| 49 | + if (enclosing(forOp, op)) |
| 50 | + ++cnt; |
| 51 | + continue; |
| 52 | + } |
| 53 | + if (auto ifOp = dyn_cast<scf::IfOp>(ctrlOp)) { |
| 54 | + if (enclosing(ifOp, op)) |
| 55 | + ++cnt; |
| 56 | + continue; |
| 57 | + } |
| 58 | + llvm_unreachable("region op other than If/For is not supported"); |
| 59 | + } |
| 60 | + return cnt; |
| 61 | +} |
| 62 | + |
| 63 | +// Assume parentForOp has accumCnt for the specified ctrlOp. |
| 64 | +unsigned getAccumArgIdx(scf::ForOp parentForOp, Operation *ctrlOp, |
| 65 | + const DenseSet<Operation *> ®ionsWithChannels) { |
| 66 | + // Walk parentForOp in preorder. |
| 67 | + unsigned preOrderId = 0, ctrlId = 0; |
| 68 | + bool found = false; |
| 69 | + parentForOp->walk<WalkOrder::PreOrder>([&](Operation *subOp) { |
| 70 | + // This will walk parentForOp. |
| 71 | + if (subOp == ctrlOp) { |
| 72 | + ctrlId = preOrderId; |
| 73 | + found = true; |
| 74 | + } |
| 75 | + for (auto *op : regionsWithChannels) { |
| 76 | + if (op == subOp) { |
| 77 | + LDBG("getAccumArgIdx: saw ctrlOp enclosing channel " << subOp); |
| 78 | + ++preOrderId; |
| 79 | + } |
| 80 | + } |
| 81 | + }); |
| 82 | + assert(found && "error in getAccumArgIdx"); |
| 83 | + LDBG("getAccumArgIdx: " << parentForOp.getOperation() << " " << ctrlOp << " " |
| 84 | + << ctrlId); |
| 85 | + return ctrlId; |
| 86 | +} |
| 87 | + |
| 88 | +// Compute and return the buffer index and phase for a given accumulate count. |
| 89 | +std::pair<Value, Value> getBufferIdxAndPhase(OpBuilderWithAsyncTaskIds &builder, |
| 90 | + Location loc, Value accumCnt, |
| 91 | + unsigned numBuffers) { |
| 92 | + Value numBuffersVal = |
| 93 | + builder.createWithAsyncTaskIds<arith::ConstantIntOp>(loc, numBuffers, 32); |
| 94 | + numBuffersVal = builder.createWithAsyncTaskIds<arith::ExtSIOp>( |
| 95 | + loc, builder.getI64Type(), numBuffersVal); |
| 96 | + // Calculate accumCnt / numBuffers |
| 97 | + // initBufferIdx = accumCnt - accumCnt / numBuffers * numBuffers |
| 98 | + // initPhase = (accumCnt / numBuffers) & 1 |
| 99 | + Value bufferIdx = builder.createWithAsyncTaskIds<arith::DivUIOp>( |
| 100 | + loc, accumCnt, numBuffersVal); |
| 101 | + Value initBufferIdx = builder.createWithAsyncTaskIds<arith::SubIOp>( |
| 102 | + loc, accumCnt, |
| 103 | + builder.createWithAsyncTaskIds<arith::MulIOp>(loc, bufferIdx, |
| 104 | + numBuffersVal)); |
| 105 | + initBufferIdx = builder.createWithAsyncTaskIds<arith::TruncIOp>( |
| 106 | + loc, builder.getI32Type(), initBufferIdx); |
| 107 | + |
| 108 | + Value one = builder.createWithAsyncTaskIds<arith::ConstantIntOp>(loc, 1, 64); |
| 109 | + bufferIdx = |
| 110 | + builder.createWithAsyncTaskIds<arith::AndIOp>(loc, bufferIdx, one); |
| 111 | + Value initPhase = builder.createWithAsyncTaskIds<arith::TruncIOp>( |
| 112 | + loc, builder.getI1Type(), bufferIdx); |
| 113 | + return {initBufferIdx, initPhase}; |
| 114 | +} |
| 115 | + |
| 116 | +// Get the current accumulation count for the given op within its immediate |
| 117 | +// scope. |
| 118 | +// ForA (accumForA, accumIfA, accumForB, accumIfB) |
| 119 | +// IfA (accumIfA, accumForB) |
| 120 | +// Channel A --> uses ForA.arg[accumIfA] |
| 121 | +// ForB (accumForB) |
| 122 | +// Channel B --> uses ForB.arg[accumForB] |
| 123 | +// ThenYield ForA.arg[accumIfA] + 1, ForB.res[accumForB] |
| 124 | +// ElseYield ForA.arg[accumIfA], ForA.arg[accumForB] |
| 125 | +// ForC (accumForC, accumIfB) |
| 126 | +// IfB |
| 127 | +// Channel C --> uses ForC.arg[accumIfB] |
| 128 | +// ThenYield ForC.arg[accumIfB] + 1 |
| 129 | +// ElseYield ForC.arg[accumIfB] |
| 130 | +// Channel D --> uses ForA.arg[accumForA] |
| 131 | +Value getAccumCount(OpBuilderWithAsyncTaskIds &builder, Operation *op, |
| 132 | + const DenseSet<Operation *> ®ionsWithChannels) { |
| 133 | + auto parentForOp = op->getParentOfType<scf::ForOp>(); |
| 134 | + auto *pOp = op->getParentOp(); |
| 135 | + // Get parentForOp.arg[pOp] |
| 136 | + unsigned tSize = parentForOp.getBody()->getArguments().size(); |
| 137 | + unsigned parentTCnts = getAccumCnts(parentForOp, regionsWithChannels); |
| 138 | + unsigned accumArgId = getAccumArgIdx(parentForOp, pOp, regionsWithChannels); |
| 139 | + Value accumCnt = |
| 140 | + parentForOp.getBody()->getArgument(tSize - parentTCnts + accumArgId); |
| 141 | + |
| 142 | + LDBG("getAccumCount: parentForOp " << parentForOp.getOperation() << " pOp " |
| 143 | + << pOp << " " << tSize << " " |
| 144 | + << parentTCnts << " " << accumArgId); |
| 145 | + return accumCnt; |
| 146 | +} |
| 147 | + |
| 148 | +void getBufferIdxAndPhase(OpBuilderWithAsyncTaskIds &builder, Operation *op, |
| 149 | + unsigned numBuffers, |
| 150 | + const DenseSet<Operation *> ®ionsWithChannels, |
| 151 | + Value &bufferIdx, Value &phase) { |
| 152 | + Value accumCnt = getAccumCount(builder, op, regionsWithChannels); |
| 153 | + std::tie(bufferIdx, phase) = |
| 154 | + getBufferIdxAndPhase(builder, op->getLoc(), accumCnt, numBuffers); |
| 155 | +} |
| 156 | + |
| 157 | +Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder, |
| 158 | + Value barrierAlloc, Value bufferIdx) { |
| 159 | + auto context = barrierAlloc.getContext(); |
| 160 | + Attribute sharedMemorySpace = |
| 161 | + triton::gpu::SharedMemorySpaceAttr::get(context); |
| 162 | + ttg::MemDescType barrierTy = ttg::MemDescType::get( |
| 163 | + {1}, builder.getI64Type(), |
| 164 | + cast<ttg::MemDescType>(barrierAlloc.getType()).getEncoding(), |
| 165 | + sharedMemorySpace, |
| 166 | + /*mutableMemory=*/true); |
| 167 | + |
| 168 | + // Create barrierForTMA from barrierAlloc. |
| 169 | + return builder.createWithAsyncTaskIds<ttg::MemDescSubviewOp>( |
| 170 | + barrierAlloc.getLoc(), barrierTy, barrierAlloc, |
| 171 | + ArrayRef<Value>({bufferIdx})); |
| 172 | +} |
| 173 | + |
| 174 | +} // namespace mlir |
0 commit comments