// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-block-dynamic-dimensions"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_BLOCKDYNAMICDIMENSIONSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

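/// Divisibility information for the dynamic dimensions of a tensor value:
/// maps the position of a dynamic dimension to the compile-time divisibility
/// known for it.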
using TensorDivisibilityInfo =
    llvm::SmallDenseMap<unsigned, IREE::Util::ConstantIntDivisibility>;

namespace {

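/// Pattern to remove `util.optimization.barrier` ops by replacing them with
/// their operands; run at the end of the pass once reshape propagation has
/// finished.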
struct RemoveOptimizationBarrier final
    : public OpRewritePattern<IREE::Util::OptimizationBarrierOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOp(barrierOp, barrierOp.getOperands());
    return success();
  }
};

/// This pass materializes information about the dynamic dimensions of
/// `tensor` operands of an operation in the IR. If a dynamic dimension is
/// known to be a multiple of a compile-time constant value, this pass
/// expands the shape of the operands. For example, if a `tensor` operand
/// is of shape `tensor<...x?x...>` and that dimension is known to be a
/// multiple of 16, the operand is expanded to `tensor<...x?x16x...>`, where
/// the size of the new dynamic dimension is 1/16th of the original dynamic
/// dimension size. This is done in two steps:
/// 1) Replace operands that have such dynamic dimensions with the result of
///    a `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize
///    the new static dimension and immediately fold it away. An optimization
///    barrier is added in between to prevent the two operations from being
///    folded away.
/// 2) Use patterns that propagate the `tensor.collapse_shape` down to
///    manipulate the operation appropriately. This allows reusing the
///    (fairly complex) logic used to expand dimensions of operations
///    implemented in the propagation patterns.
/// At the end of the pass, the optimization barriers are removed to fold
/// away any un-propagated `tensor.expand_shape`/`tensor.collapse_shape`
/// pairs.
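///
/// For example (an illustrative sketch; the actual IR also carries the
/// reassociation and output-shape attributes), a `%0 : tensor<?x4096xf16>`
/// whose dynamic dimension is known to be divisible by 16 is blocked as
///
/// ```mlir
/// %expanded = tensor.expand_shape %0
///     : tensor<?x4096xf16> into tensor<?x16x4096xf16>
/// %barrier = util.optimization.barrier %expanded
/// %collapsed = tensor.collapse_shape %barrier
///     : tensor<?x16x4096xf16> into tensor<?x4096xf16>
/// ```
///
/// after which the propagation patterns move the `tensor.collapse_shape`
/// past its consumers so that they operate on the expanded shape.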
struct BlockDynamicDimensionsPass final
    : impl::BlockDynamicDimensionsPassBase<BlockDynamicDimensionsPass> {
  void runOnOperation() override;
};
} // namespace

/// Retrieve the divisibility information for dynamic dimensions of `v` if
/// known.
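/// For example (illustrative values), if the analysis proves that dimension 0
/// of `%v : tensor<?x?xf32>` is a multiple of 16 but knows nothing about
/// dimension 1, the returned map contains a single entry mapping index 0 to a
/// divisibility of 16.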
static TensorDivisibilityInfo
getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis,
                          Value v) {
  TensorDivisibilityInfo divisibilityInfo;
  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
  if (!tensorType) {
    return divisibilityInfo;
  }

  for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
    if (!tensorType.isDynamicDim(index))
      continue;
    std::optional<IREE::Util::ConstantIntDivisibility> dimDivisibility =
        dynamicDimAnalysis.getDivisibilityInfo(v, index);
    if (!dimDivisibility)
      continue;
    divisibilityInfo[index] = std::move(dimDivisibility.value());
  }

  return divisibilityInfo;
}

/// For a value `v`, if a dynamic dimension is known to be a multiple of a
/// compile-time static value, insert
///
/// ```mlir
/// %v_expand = tensor.expand_shape %v
/// %barrier = util.optimization.barrier %v_expand
/// %v_collapse = tensor.collapse_shape %barrier
/// ```
///
/// where the generated `tensor.expand_shape` and `tensor.collapse_shape` are
/// inverses of each other. The `util.optimization.barrier` prevents these
/// from getting folded away during reshape propagation. Returns the result
/// of the generated `tensor.collapse_shape`.
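///
/// For example (illustrative; `%d` stands for the computed quotient
/// `dim(%v, 0) floordiv 16`), for `%v : tensor<?xf32>` whose dimension is a
/// multiple of 16 this generates
///
/// ```mlir
/// %v_expand = tensor.expand_shape %v [[0, 1]] output_shape [%d, 16]
///     : tensor<?xf32> into tensor<?x16xf32>
/// %barrier = util.optimization.barrier %v_expand
/// %v_collapse = tensor.collapse_shape %barrier [[0, 1]]
///     : tensor<?x16xf32> into tensor<?xf32>
/// ```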
static std::optional<Value>
blockDynamicDimensionsOfValue(RewriterBase &rewriter,
                              const TensorDivisibilityInfo &divisibilityInfo,
                              Value v) {
  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
  if (!tensorType) {
    return std::nullopt;
  }

  // Compute the expanded shape and reassociation, using the divisibility
  // information where it is available.
  SmallVector<OpFoldResult> outputShape;
  SmallVector<ReassociationIndices> reassociation;
  Location loc = v.getLoc();

  for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) {
    reassociation.emplace_back(ReassociationIndices{});

    // Check if this dimension needs to be split.
    if (!tensorType.isDynamicDim(index) || !divisibilityInfo.contains(index)) {
      reassociation.back().push_back(outputShape.size());
      outputShape.push_back(rewriter.getIndexAttr(dim));
      continue;
    }

    // Split the dynamic dimension based on the divisibility info.
    IREE::Util::ConstantIntDivisibility currDivisibility =
        divisibilityInfo.lookup(index);
    uint64_t factor = currDivisibility.sdiv();
    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
    AffineExpr divExpr = s0.floorDiv(factor);
    Value sourceDim = rewriter.create<tensor::DimOp>(loc, v, index).getResult();
    OpFoldResult newDynamicDim = affine::makeComposedFoldedAffineApply(
        rewriter, loc, divExpr, ArrayRef<OpFoldResult>{sourceDim});
    OpFoldResult newStaticDim = rewriter.getIndexAttr(factor);

    reassociation.back().push_back(outputShape.size());
    reassociation.back().push_back(outputShape.size() + 1);

    outputShape.push_back(newDynamicDim);
    outputShape.push_back(newStaticDim);
  }

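  // For example (illustrative), for `%v : tensor<4x?x8xf32>` whose dynamic
  // dimension is known to be a multiple of 16, the loop above produces
  //   outputShape   = [4, dim(%v, 1) floordiv 16, 16, 8]
  //   reassociation = [[0], [1, 2], [3]]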
  auto staticOutputShape =
      llvm::map_to_vector(outputShape, [](OpFoldResult ofr) {
        if (auto staticShapeAttr = dyn_cast<Attribute>(ofr)) {
          return cast<IntegerAttr>(staticShapeAttr).getInt();
        }
        return ShapedType::kDynamic;
      });
  auto outputType = RankedTensorType::get(
      staticOutputShape, tensorType.getElementType(), tensorType.getEncoding());

  Value expandShape = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, v, reassociation, outputShape);
  Value barrier =
      rewriter.create<IREE::Util::OptimizationBarrierOp>(loc, expandShape)
          .getResult(0);
  Value collapseShape = rewriter.create<tensor::CollapseShapeOp>(
      loc, tensorType, barrier, reassociation);
  return collapseShape;
}

/// For an operation, replace the operands at indices specified in
/// `limitToOperandIndices` with the result of a
/// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the
/// information about dynamic dimensions that are known to be a multiple of a
/// compile-time static value. For example, given
///
/// ```mlir
/// %1 = <some_op>(..., %0, ...) : ... , tensor<4x?x6xf32>
/// ```
///
/// if the dynamic dimension is known to be a multiple of 16, then generate
///
/// ```mlir
/// %expanded = tensor.expand_shape %0
///     : tensor<4x?x6xf32> into tensor<4x?x16x6xf32>
/// %barrier = util.optimization.barrier %expanded
/// %collapsed = tensor.collapse_shape %barrier
///     : tensor<4x?x16x6xf32> into tensor<4x?x6xf32>
/// %1 = <some_op>(..., %collapsed, ...) : ... , tensor<4x?x6xf32>
/// ```
static LogicalResult blockDynamicDimensions(
    RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
    Operation *operation, llvm::SmallDenseSet<int64_t> limitToOperandIndices) {
  OpBuilder::InsertionGuard g(rewriter);

  for (OpOperand &operand : operation->getOpOperands()) {
    if (!limitToOperandIndices.contains(operand.getOperandNumber()))
      continue;
    if (operand.get().getDefiningOp<tensor::CollapseShapeOp>())
      continue;
    TensorDivisibilityInfo operandDivisibilityInfo =
        getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get());
    if (operandDivisibilityInfo.empty())
      continue;
    std::optional<Value> newOperand = blockDynamicDimensionsOfValue(
        rewriter, operandDivisibilityInfo, operand.get());
    if (newOperand) {
      rewriter.modifyOpInPlace(operation,
                               [&]() { operand.set(newOperand.value()); });
    }
  }
  return success();
}

/// Insert `tensor.expand_shape` operations to materialize in the IR the
/// information about dynamic dimensions that are known to be a multiple of a
/// compile-time known value, for the operands of the
/// `iree_linalg_ext.attention` operation.
static LogicalResult
blockDynamicDimensions(RewriterBase &rewriter,
                       const TensorDynamicDimAnalysis &dynamicDimAnalysis,
                       IREE::LinalgExt::AttentionOp attentionOp) {
  // Only block the query and key operands.
  llvm::SmallDenseSet<int64_t> prunedOperandsList;
  prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber());
  prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber());
  return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp,
                                prunedOperandsList);
}

void BlockDynamicDimensionsPass::runOnOperation() {
  Operation *operation = getOperation();
  MLIRContext *context = &getContext();
  TensorDynamicDimAnalysis dynamicDimAnalysis(operation);
  if (failed(dynamicDimAnalysis.run())) {
    return signalPassFailure();
  }

  IRRewriter rewriter(context);
  auto walkResult = operation->walk(
      [&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult {
        rewriter.setInsertionPoint(attentionOp);
        return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
                                      attentionOp);
      });
  if (walkResult.wasInterrupted()) {
    return signalPassFailure();
  }

  LLVM_DEBUG({
    llvm::dbgs() << "After blocking dimensions:\n";
    operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
    llvm::dbgs() << "\n";
  });

  {
    RewritePatternSet bubbleExpandShapePatterns(context);
    // Add patterns to "push down" the `tensor.collapse_shape` ops (these are
    // the dual of the patterns that "bubble up" `tensor.expand_shape` ops).
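    // For example (an illustrative sketch), these patterns rewrite
    //
    //   %c = tensor.collapse_shape %barrier [[0, 1]]
    //       : tensor<?x16xf32> into tensor<?xf32>
    //   %1 = linalg.generic ... ins(%c : tensor<?xf32>) ...
    //
    // into a `linalg.generic` on the expanded type, with the collapse pushed
    // past it:
    //
    //   %1 = linalg.generic ... ins(%barrier : tensor<?x16xf32>) ...
    //   %c = tensor.collapse_shape %1 ...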
    linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
    linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
                                                      controlFn);
    IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
        bubbleExpandShapePatterns, controlFn);
    // Add patterns to fold the "bubbled-up" `tensor.expand_shape` ops and
    // "pushed-down" `tensor.collapse_shape` ops with interface bindings or
    // `tensor.empty` operations.
    populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
    tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
    // Add some additional patterns that can simplify the IR and remove dead
    // operations.
    memref::populateResolveRankedShapedTypeResultDimsPatterns(
        bubbleExpandShapePatterns);
    populateRemoveDeadMemAllocPatterns(bubbleExpandShapePatterns);
    if (failed(applyPatternsAndFoldGreedily(
            operation, std::move(bubbleExpandShapePatterns)))) {
      operation->emitOpError(
          "failed in application of bubble up expand shape patterns");
      return signalPassFailure();
    }
  }

  LLVM_DEBUG({
    llvm::dbgs() << "After reshape propagation:\n";
    operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
    llvm::dbgs() << "\n";
  });

  // Delete the optimization barriers and run some further cleanup.
  {
    RewritePatternSet removeBarrierOpsPatterns(context);
    removeBarrierOpsPatterns.insert<RemoveOptimizationBarrier>(context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
                                                       context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(
        removeBarrierOpsPatterns, context);
    if (failed(applyPatternsAndFoldGreedily(
            operation, std::move(removeBarrierOpsPatterns)))) {
      operation->emitOpError("failed in cleanup patterns");
      return signalPassFailure();
    }
  }
}

} // namespace mlir::iree_compiler