1616#include < cstdint>
1717#include < functional>
1818#include < optional>
19- #include < type_traits>
2019
21- #include " mlir/Dialect/Affine/IR/AffineOps.h"
2220#include " mlir/Dialect/Arith/IR/Arith.h"
2321#include " mlir/Dialect/Arith/Utils/Utils.h"
24- #include " mlir/Dialect/Linalg/IR/Linalg.h"
2522#include " mlir/Dialect/MemRef/IR/MemRef.h"
2623#include " mlir/Dialect/SCF/IR/SCF.h"
27- #include " mlir/Dialect/Tensor/IR/Tensor.h"
2824#include " mlir/Dialect/Utils/IndexingUtils.h"
2925#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
3026#include " mlir/Dialect/Vector/IR/VectorOps.h"
3127#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
3228#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
33- #include " mlir/IR/BuiltinAttributeInterfaces.h"
3429#include " mlir/IR/BuiltinTypes.h"
35- #include " mlir/IR/ImplicitLocOpBuilder.h"
3630#include " mlir/IR/Location.h"
3731#include " mlir/IR/Matchers.h"
3832#include " mlir/IR/PatternMatch.h"
3933#include " mlir/IR/TypeUtilities.h"
40- #include " mlir/Interfaces/VectorInterfaces.h"
4134
42- #include " llvm/ADT/DenseSet.h"
43- #include " llvm/ADT/MapVector.h"
4435#include " llvm/ADT/STLExtras.h"
45- #include " llvm/Support/CommandLine.h"
46- #include " llvm/Support/Debug.h"
4736#include " llvm/Support/FormatVariadic.h"
48- #include " llvm/Support/raw_ostream.h"
4937
5038#define DEBUG_TYPE " vector-to-vector"
5139
@@ -71,54 +59,6 @@ static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
7159
7260namespace {
7361
74- // / ShapeCastOpFolder folds cancelling ShapeCastOps away.
75- //
76- // Example:
77- //
78- // The following MLIR with cancelling ShapeCastOps:
79- //
80- // %0 = source : vector<5x4x2xf32>
81- // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
82- // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
83- // %3 = user %2 : vector<5x4x2xf32>
84- //
85- // Should canonicalize to the following:
86- //
87- // %0 = source : vector<5x4x2xf32>
88- // %1 = user %0 : vector<5x4x2xf32>
89- //
90- struct ShapeCastOpFolder : public OpRewritePattern <vector::ShapeCastOp> {
91- using OpRewritePattern::OpRewritePattern;
92-
93- LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
94- PatternRewriter &rewriter) const override {
95- // Check if 'shapeCastOp' has vector source/result type.
96- auto sourceVectorType =
97- dyn_cast_or_null<VectorType>(shapeCastOp.getSource ().getType ());
98- auto resultVectorType =
99- dyn_cast_or_null<VectorType>(shapeCastOp.getResult ().getType ());
100- if (!sourceVectorType || !resultVectorType)
101- return failure ();
102-
103- // Check if shape cast op source operand is also a shape cast op.
104- auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
105- shapeCastOp.getSource ().getDefiningOp ());
106- if (!sourceShapeCastOp)
107- return failure ();
108- auto operandSourceVectorType =
109- cast<VectorType>(sourceShapeCastOp.getSource ().getType ());
110- auto operandResultVectorType = sourceShapeCastOp.getType ();
111-
112- // Check if shape cast operations invert each other.
113- if (operandSourceVectorType != resultVectorType ||
114- operandResultVectorType != sourceVectorType)
115- return failure ();
116-
117- rewriter.replaceOp (shapeCastOp, sourceShapeCastOp.getSource ());
118- return success ();
119- }
120- };
121-
12262// / Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
12363// / Ex:
12464// / ```
@@ -2113,11 +2053,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
21132053 patterns.add <FoldI1Select>(patterns.getContext (), benefit);
21142054}
21152055
2116- void mlir::vector::populateShapeCastFoldingPatterns (RewritePatternSet &patterns,
2117- PatternBenefit benefit) {
2118- patterns.add <ShapeCastOpFolder>(patterns.getContext (), benefit);
2119- }
2120-
21212056void mlir::vector::populateDropUnitDimWithShapeCastPatterns (
21222057 RewritePatternSet &patterns, PatternBenefit benefit) {
21232058 // TODO: Consider either:
@@ -2126,8 +2061,7 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
21262061 // * better naming to distinguish this and
21272062 // populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
21282063 patterns.add <DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2129- DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
2130- patterns.getContext (), benefit);
2064+ DropUnitDimsFromTransposeOp>(patterns.getContext (), benefit);
21312065}
21322066
21332067void mlir::vector::populateBubbleVectorBitCastOpPatterns (
0 commit comments