|
| 1 | +// Copyright 2025 The IREE Authors |
| 2 | +// |
| 3 | +// Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" |
| 8 | +#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" |
| 9 | +#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" |
| 10 | +#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" |
| 11 | +#include "iree/compiler/DispatchCreation/Passes.h" |
| 12 | +#include "llvm/ADT/STLExtras.h" |
| 13 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 14 | +#include "mlir/IR/MLIRContext.h" |
| 15 | +#include "mlir/IR/PatternMatch.h" |
| 16 | +#include "mlir/Interfaces/FunctionInterfaces.h" |
| 17 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 18 | + |
| 19 | +#define DEBUG_TYPE "iree-dispatch-creation-propagate-encodings" |
| 20 | + |
| 21 | +namespace mlir::iree_compiler::DispatchCreation { |
| 22 | + |
| 23 | +#define GEN_PASS_DEF_PROPAGATEENCODINGSPASS |
| 24 | +#include "iree/compiler/DispatchCreation/Passes.h.inc" |
| 25 | + |
namespace {

/// Rewrite pattern that swaps `tensor.collapse_shape` -> `iree_encoding.set_encoding`
/// into `iree_encoding.set_encoding` -> `tensor.collapse_shape`, so the
/// encoding is attached to the pre-collapsed (higher-rank) tensor.
struct SwapEncodingOpWithTensorCollapseShapeOp
    : public OpRewritePattern<IREE::Encoding::SetEncodingOp> {
  using Base = OpRewritePattern<IREE::Encoding::SetEncodingOp>;
  using Base::Base;
  LogicalResult matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp,
                                PatternRewriter &rewriter) const override;
};

// TODO(#20179): Support the propagation through interfaces. It is supposed to
// be done with data-flow analysis.
/// Pass that greedily applies the propagation patterns above on a
/// function-like operation.
struct PropagateEncodingsPass
    : public DispatchCreation::impl::PropagateEncodingsPassBase<
          PropagateEncodingsPass> {
  void runOnOperation() override;
};

} // namespace
| 46 | + |
| 47 | +LogicalResult SwapEncodingOpWithTensorCollapseShapeOp::matchAndRewrite( |
| 48 | + IREE::Encoding::SetEncodingOp encodingOp, PatternRewriter &rewriter) const { |
| 49 | + auto encoding = dyn_cast<IREE::Encoding::MatmulKAttr>( |
| 50 | + encodingOp.getResultType().getEncoding()); |
| 51 | + if (!encoding) { |
| 52 | + return rewriter.notifyMatchFailure(encodingOp, "only matmul_k is handled"); |
| 53 | + } |
| 54 | + auto collapseOp = |
| 55 | + encodingOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); |
| 56 | + if (!collapseOp) { |
| 57 | + return rewriter.notifyMatchFailure(encodingOp, |
| 58 | + "expected a collapse_shape producer"); |
| 59 | + } |
| 60 | + if (!IREE::Flow::isNonNullAndOutsideDispatch(encodingOp) || |
| 61 | + !IREE::Flow::isNonNullAndOutsideDispatch(collapseOp)) { |
| 62 | + return rewriter.notifyMatchFailure( |
| 63 | + encodingOp, "expected that both operations are outside dispatch"); |
| 64 | + } |
| 65 | + |
| 66 | + ArrayRef<int32_t> kDims = encoding.getKDims().asArrayRef(); |
| 67 | + llvm::SetVector<int32_t> kDimsSet(kDims.begin(), kDims.end()); |
| 68 | + |
| 69 | + // Bail out if it is not propagable. |
| 70 | + // TODO: Relax the check to allow transforming innermost reduction dimensions. |
| 71 | + // We need to revisit the matmul_k encoding semantic. |
| 72 | + SmallVector<ReassociationIndices, 4> reassociationMaps = |
| 73 | + collapseOp.getReassociationIndices(); |
| 74 | + for (int32_t k : kDims) { |
| 75 | + if (reassociationMaps[k].size() != 1) { |
| 76 | + return rewriter.notifyMatchFailure( |
| 77 | + encodingOp, |
| 78 | + "expected collaps_shape ops to not transform k dimensions"); |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + // Get a mapping from original iteration space to expanded iteration space. |
| 83 | + SmallVector<int32_t> newKDims; |
| 84 | + for (int32_t kDim : kDims) { |
| 85 | + newKDims.append(reassociationMaps[kDim].begin(), |
| 86 | + reassociationMaps[kDim].end()); |
| 87 | + } |
| 88 | + |
| 89 | + // Create the new encoding op. |
| 90 | + MLIRContext *ctx = rewriter.getContext(); |
| 91 | + auto newEncodingAttr = IREE::Encoding::MatmulKAttr::get(ctx, newKDims); |
| 92 | + RankedTensorType newEncodingType = |
| 93 | + collapseOp.getSrcType().cloneWithEncoding(newEncodingAttr); |
| 94 | + Value newEncodingOp = rewriter.create<IREE::Encoding::SetEncodingOp>( |
| 95 | + encodingOp.getLoc(), newEncodingType, collapseOp.getSrc()); |
| 96 | + Value newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( |
| 97 | + collapseOp.getLoc(), encodingOp.getResultType(), newEncodingOp, |
| 98 | + collapseOp.getReassociationIndices()); |
| 99 | + rewriter.replaceOp(encodingOp, newCollapseOp); |
| 100 | + return success(); |
| 101 | +} |
| 102 | + |
| 103 | +void PropagateEncodingsPass::runOnOperation() { |
| 104 | + mlir::FunctionOpInterface funcOp = getOperation(); |
| 105 | + MLIRContext *ctx = &getContext(); |
| 106 | + RewritePatternSet propagationPatterns(ctx); |
| 107 | + propagationPatterns.insert<SwapEncodingOpWithTensorCollapseShapeOp>(ctx); |
| 108 | + GreedyRewriteConfig config; |
| 109 | + config.fold = true; |
| 110 | + config.cseConstants = false; |
| 111 | + if (failed(applyPatternsGreedily(funcOp, std::move(propagationPatterns), |
| 112 | + config))) { |
| 113 | + funcOp.emitOpError("failed to propagate encodings"); |
| 114 | + return signalPassFailure(); |
| 115 | + } |
| 116 | +} |
| 117 | + |
| 118 | +} // namespace mlir::iree_compiler::DispatchCreation |