|
9 | 9 | #ifndef MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_ |
10 | 10 | #define MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_ |
11 | 11 |
|
| 12 | +#include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 13 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
12 | 14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
13 | 15 | #include "mlir/IR/PatternMatch.h" |
| 16 | +#include "mlir/IR/Value.h" |
14 | 17 |
|
| 18 | +#include <numeric> |
15 | 19 | #include <utility> |
16 | 20 |
|
17 | 21 | namespace mlir { |
18 | 22 | namespace gpu { |
/// Move scalar operations with no dependency on the warp op outside of the
/// region. Such code is uniform across lanes (per the function name), so it
/// does not need to live inside the distributed region.
void moveScalarUniformCode(gpu::WarpExecuteOnLane0Op op);
23 | 26 |
|
/// Common base class for rewrite patterns that distribute code out of a
/// WarpExecuteOnLane0Op region. `T` is the op type the concrete pattern
/// matches; the base OpRewritePattern is anchored on WarpExecuteOnLane0Op.
template <typename T>
struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  // Derived patterns must implement the rewrite; declared pure virtual here.
  virtual LogicalResult
  matchAndRewrite(T op, PatternRewriter &rewriter) const override = 0;

protected:
  /// Return a value yielded by `warpOp` which satisfies the filter lambda
  /// condition and is not dead.
  static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                  const std::function<bool(Operation *)> &fn);

  /// Helper to create a new WarpExecuteOnLane0Op with different signature.
  static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
      RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
      ValueRange newYieldedValues, TypeRange newReturnTypes);

  /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
  /// `indices` return the index of each new output.
  static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
      RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
      ValueRange newYieldedValues, TypeRange newReturnTypes,
      llvm::SmallVector<size_t> &indices);

  /// Delinearize the given `laneId` into multiple dimensions, where each
  /// dimension's size is determined by `originalShape` and `distributedShape`
  /// together. This function expects the total numbers of threads needed for
  /// distribution is equal to `warpSize`. Returns true and updates
  /// `delinearizedIds` if so.
  static bool delinearizeLaneId(OpBuilder &builder, Location loc,
                                ArrayRef<int64_t> originalShape,
                                ArrayRef<int64_t> distributedShape,
                                int64_t warpSize, Value laneId,
                                SmallVectorImpl<Value> &delinearizedIds);
};
| 62 | + |
/// Build a replacement WarpExecuteOnLane0Op with result types
/// `newReturnTypes`, move the original body region into it, and rewrite its
/// terminator to yield `newYieldedValues`. The original `warpOp` is left in
/// place; callers are responsible for replacing/erasing it.
template <typename T>
WarpExecuteOnLane0Op
WarpDistributionPattern<T>::moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  // Steal the old op's body, then drop the placeholder block the builder
  // created for the new op (inline first so the region is never empty).
  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  // Swap the terminator's operands for the new set of yielded values.
  rewriter.modifyOpInPlace(
      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
  return newWarpOp;
}
28 | 90 |
|
/// Create a new WarpExecuteOnLane0Op that yields `newYieldedValues` (typed
/// `newReturnTypes`) in addition to the existing results, and replace
/// `warpOp` with it. Values the region already yields are not duplicated.
/// For each requested value, its result index in the new op is appended to
/// `indices`.
template <typename T>
WarpExecuteOnLane0Op
WarpDistributionPattern<T>::moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<gpu::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  // SetVector deduplicates while preserving the existing yield order, so
  // result indices of the original op are unchanged.
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      // Newly yielded value: it occupies the next result slot.
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region, don't create a new output;
      // reuse the index of the existing yield operand.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  // The leading results of the new op line up with the old op's results.
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}
| 125 | + |
| 126 | +template <typename T> |
| 127 | +OpOperand *WarpDistributionPattern<T>::getWarpResult( |
| 128 | + WarpExecuteOnLane0Op warpOp, const std::function<bool(Operation *)> &fn) { |
| 129 | + auto yield = cast<gpu::YieldOp>( |
| 130 | + warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| 131 | + for (OpOperand &yieldOperand : yield->getOpOperands()) { |
| 132 | + Value yieldValues = yieldOperand.get(); |
| 133 | + Operation *definedOp = yieldValues.getDefiningOp(); |
| 134 | + if (definedOp && fn(definedOp)) { |
| 135 | + if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) |
| 136 | + return &yieldOperand; |
| 137 | + } |
| 138 | + } |
| 139 | + return {}; |
| 140 | +} |
| 141 | + |
| 142 | +template <typename T> |
| 143 | +bool WarpDistributionPattern<T>::delinearizeLaneId( |
| 144 | + OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape, |
| 145 | + ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId, |
| 146 | + SmallVectorImpl<Value> &delinearizedIds) { |
| 147 | + // If the original shape and the distributed shape is the same, we don't |
| 148 | + // distribute at all--every thread is handling the whole. For such case, we |
| 149 | + // should not rely on lane IDs later. So just return an empty lane ID vector. |
| 150 | + if (originalShape == distributedShape) { |
| 151 | + delinearizedIds.clear(); |
| 152 | + return true; |
| 153 | + } |
| 154 | + |
| 155 | + SmallVector<int64_t> sizes; |
| 156 | + for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) { |
| 157 | + if (large % small != 0) |
| 158 | + return false; |
| 159 | + sizes.push_back(large / small); |
| 160 | + } |
| 161 | + if (std::accumulate(sizes.begin(), sizes.end(), 1, |
| 162 | + std::multiplies<int64_t>()) != warpSize) |
| 163 | + return false; |
| 164 | + |
| 165 | + AffineExpr s0, s1; |
| 166 | + bindSymbols(builder.getContext(), s0, s1); |
| 167 | + |
| 168 | + int64_t usedThreads = 1; |
| 169 | + |
| 170 | + Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| 171 | + delinearizedIds.assign(sizes.size(), zero); |
| 172 | + |
| 173 | + for (int i = sizes.size() - 1; i >= 0; --i) { |
| 174 | + usedThreads *= sizes[i]; |
| 175 | + if (usedThreads == warpSize) { |
| 176 | + // We've used up all available threads. Don't need to perform modulo |
| 177 | + // anymore. And we can stop the calculation for further dimensions. |
| 178 | + delinearizedIds[i] = laneId; |
| 179 | + break; |
| 180 | + } |
| 181 | + delinearizedIds[i] = |
| 182 | + affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId}); |
| 183 | + laneId = affine::makeComposedAffineApply( |
| 184 | + builder, loc, s0.floorDiv(usedThreads), {laneId}); |
| 185 | + } |
| 186 | + return true; |
| 187 | +} |
53 | 188 |
|
54 | 189 | } // namespace gpu |
55 | 190 | } // namespace mlir |
|
0 commit comments