#include "amd/lib/TritonAMDGPUToLLVM/Utility.h"
#include "amd/lib/TritonAMDGPUTransforms/Utility.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#define GEN_PASS_CLASSES
#include "TritonAMDGPUTransforms/Passes.h"

// This pass updates the waitCount of `AsyncWait` ops to represent the number
// of in-flight async load operations between the async_wait and the
// definition of its AsyncToken. This lets us wait only on the async loads the
// token depends on, while loads issued afterwards may still complete later.
// This also means we must never overestimate the value, to ensure
// correctness; being conservative and underestimating is fine because it only
// affects performance.
// For each async_wait we need to compute the minimum across all AsyncToken
// operands. For each token, the minimum number of async transactions along
// its def chain is deduced. A token can be both passed in as a loop initial
// argument and yielded from the loop body, in which case we take the minimum
// along both paths.
// We do not exit early if we encounter another async_wait along the def
// chain, because the pipeliner already merges redundant waits for us.
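//
// Illustrative example (hypothetical, simplified IR; types are elided and
// each copy is assumed to lower to a single load instruction):
//
//   %t0 = ttg.async_copy_global_to_local %ptrA, %bufA
//   %c0 = ttg.async_commit_group %t0
//   %t1 = ttg.async_copy_global_to_local %ptrB, %bufB
//   %c1 = ttg.async_commit_group %t1
//   ttg.async_wait %c0 {num = 1 : i32}
//
// Only one load instruction (the one behind %t1) is issued between the
// definition of %c0 and the async_wait, so its count can be set to 1: the
// load behind %c0 is done once at most one load is still outstanding, while
// %t1's load may still be in flight.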

using namespace mlir;
namespace tt = triton;
namespace ttg = triton::gpu;
// Returns the number of individual async load memory transactions when
// copying data from the given |srcTy| in global memory to the given |dstTy|
// in shared memory.
int getNumberOfLoadInstructions(RankedTensorType srcTy,
                                ttg::MemDescType dstTy) {
  auto shape = srcTy.getShape();
  LinearLayout srcLayout = tt::gpu::toLinearLayout(shape, srcTy.getEncoding());
  LinearLayout sharedLayout =
      tt::gpu::toLinearLayout(shape, dstTy.getEncoding());
  LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);

  // On GFX9 we cannot split direct-to-LDS loads into multiple ones because we
  // need coalesced writes, so we can divide the number of registers by the
  // contiguity to get the number of load instructions.
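  // Worked example (hypothetical numbers): with 8 elements in the "register"
  // dimension per thread and 4 consecutive elements written contiguously to
  // shared memory, the copy lowers to max(1, 8 / 4) = 2 load instructions.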
  int contig = srcToSharedLayout.getNumConsecutiveInOut();
  int numberOfRegisters = srcToSharedLayout.getInDimSize(
      StringAttr::get(srcTy.getContext(), "register"));
  int loadInstructionCount = std::max(1, numberOfRegisters / contig);
  return loadInstructionCount;
}

// The pipeliner always inserts ops following the order ttg.async_load ->
// [token] -> ttg.async_commit_group -> [token] -> ttg.async_wait. So here we
// scan the operands of ttg.async_commit_group to count the number of issued
// async load intrinsics.
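// For example (hypothetical counts): a ttg.async_commit_group whose token
// operands come from two async copies that each lower to 2 load instructions
// contributes a count of 4 to any def chain it appears on.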
int getNumberOfLoadInstructions(Operation *op) {
  if (isa<ttg::AsyncCommitGroupOp>(op)) {
    int count = 0;
    for (auto token : op->getOperands()) {
      auto defOp = token.getDefiningOp();
      if (!defOp)
        continue;
      if (auto copyOp = llvm::dyn_cast<ttg::AsyncCopyGlobalToLocalOp>(defOp)) {
        count += getNumberOfLoadInstructions(copyOp.getSrc().getType(),
                                             copyOp.getResult().getType());
      } else if (auto copyOp =
                     llvm::dyn_cast<amdgpu::BufferLoadToLocalOp>(defOp)) {
        auto srcTy = cast<RankedTensorType>(LLVM::AMD::getPointerTypeWithShape(
            copyOp.getPtr(), copyOp.getOffsets()));
        count += getNumberOfLoadInstructions(srcTy, copyOp.getDest().getType());
      }
    }
    return count;
  }
  if (isa<tt::LoadOp, tt::StoreOp, amdgpu::BufferLoadToLocalOp,
          amdgpu::BufferStoreOp, tt::AtomicRMWOp, tt::AtomicCASOp,
          amdgpu::BufferAtomicRMWOp>(op)) {
    op->emitRemark("Global memory operation between async wait and "
                   "async_loads. This will hinder the interleaving of memory "
                   "operations and might impact performance.");
  }
  return 0;
}

// LLVM cannot infer the dependency between direct-to-LDS (async) loads and
// local reads across warps in a workgroup. As a workaround we update the
// waitcnt to represent the number of hardware instructions we are
// interleaving with. This allows us to manually emit the waitcnt during
// lowering.
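// (For reference: on GFX9 the lowering is expected to emit this count as an
// `s_waitcnt vmcnt(N)`, which blocks until at most N vector memory operations
// are still outstanding.)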
void updateWaitCount(ttg::AsyncWaitOp waitOp, RewriterBase &rewriter) {
  int waitCnt = std::numeric_limits<int>::max();

  // AsyncWait can await multiple tokens, so we take the minimum over all
  // tokens.
  for (auto token : waitOp.getOperands()) {
    // Traverse the def chain from waitOp to the producer of the token and
    // count the minimum number of vmcnt instructions.
    auto tokenWaitCnt =
        deduceMinCountOnDefChain(token, waitOp, [](Operation *op) {
          return getNumberOfLoadInstructions(op);
        });
    waitCnt = std::min(waitCnt, tokenWaitCnt);
  }

  if (waitCnt == std::numeric_limits<int>::max() || waitOp.getNum() == waitCnt)
    return;

  rewriter.modifyOpInPlace(waitOp, [&]() { waitOp.setNum(waitCnt); });
}

struct TritonAMDGPUUpdateAsyncWaitCountPass
    : public TritonAMDGPUUpdateAsyncWaitCountBase<
          TritonAMDGPUUpdateAsyncWaitCountPass> {
  TritonAMDGPUUpdateAsyncWaitCountPass(StringRef archGenName) {
    this->archGenerationName = archGenName.str();
  }

  void runOnOperation() override {
    tt::AMD::TargetInfo targetInfo(archGenerationName);
    if (!targetInfo.isCDNA()) {
      return;
    }

    ModuleOp m = getOperation();

    SmallVector<ttg::AsyncWaitOp> waitOps;
    m->walk([&](ttg::AsyncWaitOp waitOp) { waitOps.push_back(waitOp); });

    for (auto waitOp : waitOps) {
      IRRewriter builder(waitOp->getContext());
      updateWaitCount(waitOp, builder);
    }
  }
};

std::unique_ptr<Pass>
mlir::createTritonAMDGPUUpdateAsyncWaitCountPass(std::string archGenName) {
  return std::make_unique<TritonAMDGPUUpdateAsyncWaitCountPass>(archGenName);
}
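
// Typical usage (illustrative): the pass is expected to run after software
// pipelining has created the async_wait ops, e.g.
//   pm.addPass(mlir::createTritonAMDGPUUpdateAsyncWaitCountPass(archName));
// where `pm` is an mlir::PassManager and `archName` names the target GPU
// generation (e.g. "gfx942").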