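// Coalescing pass for the Intel backend: pick a coalesced blocked layout for
// each global memory access based on the pointer contiguity computed by the
// AxisInfo analysis, and rewrite the memory ops to use that layout.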
#include "intel/include/Analysis/AxisInfo.h"
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tritonintelgpu-coalesce"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir::triton::gpu::intel {
#define GEN_PASS_DEF_TRITONINTELGPUCOALESCE
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
} // namespace mlir::triton::gpu::intel

using namespace mlir;
namespace tt = mlir::triton;
namespace ttgi = mlir::triton::gpu::intel;

namespace {

struct CoalescePass
    : public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
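  // Compute, for the memory access `op`, a blocked encoding that yields
  // coalesced accesses given the pointer's contiguity, and record it in
  // `layoutMap`.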
  void
  setCoalescedEncoding(tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
                       Operation *op, int numWarps, int threadsPerWarp,
                       llvm::MapVector<Operation *, Attribute> &layoutMap) {
    Value ptr = getMemAccessPtr(op);
    auto refTensorType = cast<RankedTensorType>(ptr.getType());

    LDBG("Considering op: " << *op);
    LLVM_DEBUG({
      DBGS() << "axis info of pointer: ";
      axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
    SmallVector<unsigned> order = argSort(contiguity);
    LDBG("order=[" << triton::join(order, ", ") << "]");
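    // `order` lists the tensor dimensions from most to least contiguous, so
    // order[0] is the dimension along which consecutive threads should access
    // consecutive elements (e.g. contiguity [1, 4] gives order [1, 0]).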

    auto matchesShape = [&refTensorType](const Value &val) {
      auto rttType = dyn_cast<RankedTensorType>(val.getType());
      return rttType && rttType.getShape() == refTensorType.getShape();
    };

    // The desired divisibility is the maximum divisibility among all dependent
    // pointers which have the same shape and order as `ptr`.
    llvm::SmallSetVector<Operation *, 32> memAccessesSameOrder;
    memAccessesSameOrder.insert(op);
    if (ptr.getDefiningOp()) {
      for (Operation *use : mlir::multiRootGetSlice(op)) {
        Value val = getMemAccessPtr(use);
        if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
          continue;
        auto currOrder =
            argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
        if (order == currOrder) {
          LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
          memAccessesSameOrder.insert(use);
        }
      }
    }

    auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType);
    LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]");

    int numElems = product<int64_t>(shapePerCTA);
    int numThreads = numWarps * threadsPerWarp;

    unsigned perThread =
        ttgi::getNumElementsPerThread(op, order, axisInfoAnalysis);
    LDBG("perThread for op: " << perThread);

    for (Operation *opSameOrder : memAccessesSameOrder) {
      if (opSameOrder == op)
        continue;
      unsigned currPerThread =
          ttgi::getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis);
      LDBG("perThread for opSameOrder: " << currPerThread);
      perThread = std::max(perThread, currPerThread);
    }

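    // Cap perThread at a thread's even share of the CTA tile (numElems /
    // numThreads, but at least 1). Illustrative numbers: a 64x64 tensor with
    // 4 warps of 16 threads gives numElems = 4096 and numThreads = 64, so
    // perThread is capped at 64.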
    perThread = std::min<int>(perThread, std::max(numElems / numThreads, 1));
    LDBG("perThread: " << perThread);

    if (!isa<triton::LoadOp>(op)) {
      // For ops that can result in a global memory write, we should enforce
      // that each thread handles at most 128 bits, which is the widest
      // available vectorized store op; otherwise, the store will have "gaps"
      // in the memory write at the warp level, resulting in worse performance.
      // For loads, we can expect that the gaps won't matter due to the L1
      // cache.
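      // (For instance, 128 bits corresponds to at most 4 f32 elements or
      // 8 f16 elements per thread.)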
      perThread = std::min<int>(perThread, ttgi::getNumElementsPerThread(
                                               op, order, axisInfoAnalysis));
    }
    SmallVector<unsigned> sizePerThread(refTensorType.getRank(), 1);
    sizePerThread[order[0]] = perThread;

    auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding());
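    // Illustrative result (attribute syntax abbreviated and version
    // dependent): for a 64x64 tensor with order = [1, 0], perThread = 4, and
    // 4 warps of 16 threads, the chosen layout is roughly
    //   #blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16],
    //             warpsPerCTA = [4, 1], order = [1, 0]}>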
    layoutMap[op] = triton::gpu::BlockedEncodingAttr::get(
        &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps,
        threadsPerWarp, CTALayout);
  }

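  // Return `type` with the same shape and element type but with `encoding`.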
  static Type getNewType(Type type, Attribute encoding) {
    RankedTensorType tensorType = cast<RankedTensorType>(type);
    return RankedTensorType::get(tensorType.getShape(),
                                 tensorType.getElementType(), encoding);
  }

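  // Rewrite `op` to use the coalesced `encoding`: wrap its tensor operands in
  // convert_layout ops, clone the memory op so that its results carry the new
  // encoding, and convert the results back to their original layouts.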
  void coalesceOp(Attribute encoding, Operation *op) {
    OpBuilder builder(op);
    // Convert operands.
    // For load/store ops on tensor pointers we don't have to change the
    // operand types here; instead, the result type of `make_tensor_ptr` is
    // changed.
    SmallVector<Value, 4> newArgs;
    for (auto operand : op->getOperands()) {
      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
      if (tensorType &&
          !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
        Type newType = getNewType(tensorType, encoding);
        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), newType, operand));
      } else {
        newArgs.push_back(operand);
      }
    }

    // Convert output types.
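    // AsyncCopyGlobalToLocalOp produces an async token rather than a tensor,
    // so its result type is left untouched.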
    SmallVector<Type, 4> newTypes;
    for (auto t : op->getResultTypes()) {
      bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
    }

    // Construct new op with the new encoding.
    Operation *newOp =
        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
                       newTypes, op->getAttrs());

    // Cast the results back to the original layout.
    for (size_t i = 0; i < op->getNumResults(); i++) {
      Value newResult = newOp->getResult(i);
      if (newTypes[i] != op->getResultTypes()[i]) {
        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), op->getResult(i).getType(), newResult);
      }
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }

  void runOnOperation() override {
    // Run axis info analysis.
    ModuleOp moduleOp = getOperation();
    tt::intel::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);

    // For each memory (I/O) operation, determine the layout its pointer
    // operand should have for best memory coalescing.
    llvm::MapVector<Operation *, Attribute> layoutMap;
    moduleOp.walk([&](Operation *curr) {
      Value ptr = getMemAccessPtr(curr);
      if (!ptr)
        return;
      // We only convert load/store ops whose pointer is a tensor of pointers
      // (`tensor<...x!tt.ptr<...>>`).
      bool isPtrTensor = false;
      if (auto tensorType = dyn_cast<RankedTensorType>(ptr.getType()))
        isPtrTensor = isa<tt::PointerType>(tensorType.getElementType());
      if (!isPtrTensor)
        return;
      auto mod = curr->getParentOfType<ModuleOp>();
      int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
      int threadsPerWarp =
          triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
      setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
                           layoutMap);
    });

    // For each memory op that has a layout L1:
    // 1. Create a coalesced memory layout L2 of the pointer operands
    // 2. Convert all operands from layout L1 to layout L2
    // 3. Create a new memory op that consumes these operands and
    //    produces a tensor with layout L2
    // 4. Convert the output of this new memory op back to L1
    // 5. Replace all the uses of the original memory op by the new one
    // A schematic example (with hypothetical shapes and layout names) follows.
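    //
    //   %v = tt.load %ptr : tensor<64x64x!tt.ptr<f32>, #L1>
    // becomes (exact IR syntax depends on the Triton version):
    //   %p2 = convert_layout %ptr : #L1 -> #L2
    //   %v2 = tt.load %p2 : tensor<64x64x!tt.ptr<f32>, #L2>
    //   %v  = convert_layout %v2 : #L2 -> #L1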
    for (auto &kv : layoutMap) {
      coalesceOp(kv.second, kv.first);
    }
  }
};

} // namespace