|
16 | 16 | #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
|
17 | 17 | #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
18 | 18 | #include "mlir/Dialect/StandardOps/IR/Ops.h"
|
| 19 | +#include "mlir/IR/AffineExprVisitor.h" |
19 | 20 | #include "mlir/IR/Matchers.h"
|
20 | 21 | #include "mlir/IR/OpImplementation.h"
|
21 | 22 | #include "mlir/IR/PatternMatch.h"
|
22 | 23 |
|
23 | 24 | #include "llvm/ADT/DenseMap.h"
|
24 | 25 | #include "llvm/ADT/SetVector.h"
|
| 26 | +#include "llvm/ADT/SmallSet.h" |
25 | 27 | #include "llvm/ADT/StringSet.h"
|
26 | 28 | #include "llvm/Support/FormatVariadic.h"
|
27 | 29 | #include "llvm/Support/MathExtras.h"
|
@@ -86,6 +88,82 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
|
86 | 88 | return res;
|
87 | 89 | }
|
88 | 90 |
|
| 91 | +/// Visitor to check if any of the given set of positions from AffineDimExprs |
| 92 | +/// are used within an AffineExpr. |
| 93 | +struct HasAffineDimExprVisitor |
| 94 | + : public AffineExprVisitor<HasAffineDimExprVisitor, bool> { |
| 95 | + HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions) |
| 96 | + : positions(positions) {} |
| 97 | + |
| 98 | + bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) { |
| 99 | + return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS()); |
| 100 | + } |
| 101 | + |
| 102 | + bool visitDimExpr(AffineDimExpr dimExpr) { |
| 103 | + return positions.count(dimExpr.getPosition()); |
| 104 | + } |
| 105 | + |
| 106 | + bool visitConstantExpr(AffineConstantExpr constExpr) { return false; } |
| 107 | + |
| 108 | + bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; } |
| 109 | + |
| 110 | +private: |
| 111 | + llvm::SmallSet<unsigned, 4> positions; |
| 112 | +}; |
| 113 | + |
| 114 | +Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b, |
| 115 | + Location loc, |
| 116 | + unsigned resultIdx, |
| 117 | + unsigned dim) { |
| 118 | + // An example that helps understand the logic below. |
| 119 | + // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) |
| 120 | + // We want to express the shape of dim 0 of O in terms of shape of the inputs. |
| 121 | + // This is achieved as follows. |
| 122 | + // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1) |
| 123 | + // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1) |
| 124 | + // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2) |
| 125 | + // resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap) |
| 126 | + // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1) |
| 127 | + AffineMap loopsToShapesMap = getLoopsToShapesMap(); |
| 128 | + |
| 129 | + // Find the position in the above map that represents the shape of the |
| 130 | + // result:dim being inferred. |
| 131 | + Optional<unsigned> resultDimSubMapPos = |
| 132 | + getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim); |
| 133 | + if (!resultDimSubMapPos) |
| 134 | + return {}; |
| 135 | + |
| 136 | + /// From loopsToShapesMap extract the submap that represents the shape of the |
| 137 | + /// (resultIdx, dim) needed |
| 138 | + AffineMap loopToResultDimShapeMap = |
| 139 | + loopsToShapesMap.getSubMap(*resultDimSubMapPos); |
| 140 | + AffineMap operandShapesToResultDimMap = |
| 141 | + loopToResultDimShapeMap.compose(getShapesToLoopsMap()); |
| 142 | + |
| 143 | + // Check that the result dim map does not contain the positions corresponding |
| 144 | + // to the outputs. |
| 145 | + llvm::SmallSet<unsigned, 4> outputDims; |
| 146 | + unsigned outputDimPosStart = |
| 147 | + getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue(); |
| 148 | + unsigned outputDimPosEnd = |
| 149 | + getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1, |
| 150 | + getOutputOpOperands() |
| 151 | + .back() |
| 152 | + .get() |
| 153 | + .getType() |
| 154 | + .cast<ShapedType>() |
| 155 | + .getRank() - |
| 156 | + 1) |
| 157 | + .getValue(); |
| 158 | + llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd), |
| 159 | + [&outputDims](unsigned dim) { outputDims.insert(dim); }); |
| 160 | + HasAffineDimExprVisitor checkDimExpr(outputDims); |
| 161 | + if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0))) |
| 162 | + return llvm::None; |
| 163 | + return applyMapToValues(b, loc, operandShapesToResultDimMap, |
| 164 | + createFlatListOfOperandDims(b, loc))[0]; |
| 165 | +} |
| 166 | + |
89 | 167 | /// Forward declarations.
|
90 | 168 | template <typename NamedStructuredOpType>
|
91 | 169 | static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
|
@@ -2022,6 +2100,49 @@ struct FoldTensorCastOp : public RewritePattern {
|
2022 | 2100 | return success();
|
2023 | 2101 | }
|
2024 | 2102 | };
|
| 2103 | + |
| 2104 | +/// Replaces std.dim operations that use the result of a LinalgOp (on tensors) |
| 2105 | +/// with std.dim operations that use one of the arguments. For example, |
| 2106 | +/// |
| 2107 | +/// %0 = linalg.matmul ins(%arg0, %arg1, ...) |
| 2108 | +/// %1 = dim %0, %c0 |
| 2109 | +/// |
| 2110 | +/// with |
| 2111 | +/// |
| 2112 | +/// %1 = dim %arg0, %c0 |
| 2113 | +/// |
| 2114 | +/// where possible. With this the result of the `linalg.matmul` is not used in |
| 2115 | +/// dim operations. If the value produced is replaced with another value (say by |
| 2116 | +/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of |
| 2117 | +/// used in a dim op that would prevent the DCE of this op. |
| 2118 | +struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> { |
| 2119 | + using OpRewritePattern<DimOp>::OpRewritePattern; |
| 2120 | + |
| 2121 | + LogicalResult matchAndRewrite(DimOp dimOp, |
| 2122 | + PatternRewriter &rewriter) const override { |
| 2123 | + Value dimValue = dimOp.memrefOrTensor(); |
| 2124 | + Optional<int64_t> dimIndex = dimOp.getConstantIndex(); |
| 2125 | + if (!dimIndex) |
| 2126 | + return failure(); |
| 2127 | + auto linalgOp = dimValue.getDefiningOp<LinalgOp>(); |
| 2128 | + if (!linalgOp) |
| 2129 | + return failure(); |
| 2130 | + |
| 2131 | + unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber(); |
| 2132 | + Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes( |
| 2133 | + rewriter, dimOp.getLoc(), resultIndex, |
| 2134 | + static_cast<unsigned>(*dimIndex)); |
| 2135 | + if (!operandDimValue) { |
| 2136 | + // Its always possible to replace using the corresponding `outs` |
| 2137 | + // parameter. |
| 2138 | + operandDimValue = rewriter.create<DimOp>( |
| 2139 | + dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex); |
| 2140 | + } |
| 2141 | + rewriter.replaceOp(dimOp, *operandDimValue); |
| 2142 | + return success(); |
| 2143 | + } |
| 2144 | +}; |
| 2145 | + |
2025 | 2146 | } // namespace
|
2026 | 2147 |
|
2027 | 2148 | namespace {
|
@@ -2166,34 +2287,14 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
|
2166 | 2287 | return success();
|
2167 | 2288 | }
|
2168 | 2289 | };
|
2169 |
| - |
2170 |
| -/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg |
2171 |
| -/// with the corresponding output tensor argument of the linalg op. |
2172 |
| -struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> { |
2173 |
| - using OpRewritePattern<DimOp>::OpRewritePattern; |
2174 |
| - |
2175 |
| - LogicalResult matchAndRewrite(DimOp dimOp, |
2176 |
| - PatternRewriter &rewriter) const override { |
2177 |
| - Value dimOpArg = dimOp.memrefOrTensor(); |
2178 |
| - auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>(); |
2179 |
| - if (!linalgOp) |
2180 |
| - return failure(); |
2181 |
| - |
2182 |
| - auto results = linalgOp.getOperation()->getResults(); |
2183 |
| - int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg)); |
2184 |
| - auto outputTensors = linalgOp.getOutputTensors(); |
2185 |
| - rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index()); |
2186 |
| - return success(); |
2187 |
| - } |
2188 |
| -}; |
2189 | 2290 | } // namespace
|
2190 | 2291 |
|
2191 | 2292 | #define CANONICALIZERS_AND_FOLDERS(XXX) \
|
2192 | 2293 | void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
|
2193 | 2294 | MLIRContext *context) { \
|
2194 | 2295 | results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
|
2195 | 2296 | RemoveIdentityLinalgOps>(); \
|
2196 |
| - results.insert<ReplaceDimOfLinalgResult>(context); \ |
| 2297 | + results.insert<ReplaceDimOfLinalgOpResult>(context); \ |
2197 | 2298 | } \
|
2198 | 2299 | \
|
2199 | 2300 | LogicalResult XXX::fold(ArrayRef<Attribute>, \
|
|
0 commit comments