@@ -2182,6 +2182,91 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
21822182 }
21832183};
21842184
2185+ // / For example,
2186+ // / ```
2187+ // / %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to
2188+ // / vector<2x2x1xf32>
2189+ // / ```
2190+ // / becomes
2191+ // / ```
2192+ // / %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32>
2193+ // / ```
2194+ struct TransposeToShapeCast final
2195+ : public OpRewritePattern<vector::TransposeOp> {
2196+ using OpRewritePattern::OpRewritePattern;
2197+ LogicalResult matchAndRewrite (vector::TransposeOp transpose,
2198+ PatternRewriter &rewriter) const override {
2199+ if (!isOrderPreserving (transpose)) {
2200+ return rewriter.notifyMatchFailure (
2201+ transpose, " not order preserving, so not semantically a 'copy'" );
2202+ }
2203+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
2204+ transpose, transpose.getType (), transpose.getVector ());
2205+ return success ();
2206+ }
2207+ };
2208+
2209+ // / For example,
2210+ // / ```
2211+ // / %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2212+ // / ```
2213+ // / becomes
2214+ // / ```
2215+ // / %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2216+ // / ```
2217+ struct BroadcastToShapeCast final
2218+ : public OpRewritePattern<vector::BroadcastOp> {
2219+ using OpRewritePattern::OpRewritePattern;
2220+ LogicalResult matchAndRewrite (vector::BroadcastOp broadcast,
2221+ PatternRewriter &rewriter) const override {
2222+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType ());
2223+ if (!sourceType) {
2224+ return rewriter.notifyMatchFailure (
2225+ broadcast, " source is a scalar, shape_cast doesn't support scalar" );
2226+ }
2227+
2228+ VectorType outType = broadcast.getType ();
2229+ if (sourceType.getNumElements () != outType.getNumElements ())
2230+ return failure ();
2231+
2232+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(broadcast, outType,
2233+ broadcast.getSource ());
2234+ return success ();
2235+ }
2236+ };
2237+
2238+ // / For example,
2239+ // / ```
2240+ // / %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2241+ // / ```
2242+ // / becomes
2243+ // / ```
2244+ // / %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2245+ // / ```
2246+ struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
2247+ using OpRewritePattern::OpRewritePattern;
2248+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
2249+ PatternRewriter &rewriter) const override {
2250+ VectorType sourceType = extractOp.getSourceVectorType ();
2251+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2252+ if (!outType)
2253+ return failure ();
2254+
2255+ // Negative values in `position` indicates poison, cannot convert to
2256+ // shape_cast
2257+ if (llvm::any_of (extractOp.getMixedPosition (),
2258+ [](OpFoldResult v) { return !isConstantIntValue (v, 0 ); }))
2259+ return failure ();
2260+
2261+ if (sourceType.getNumElements () != outType.getNumElements ())
2262+ return failure ();
2263+
2264+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(extractOp, outType,
2265+ extractOp.getVector ());
2266+ return success ();
2267+ }
2268+ };
2269+
21852270} // namespace
21862271
21872272void mlir::vector::populateFoldArithExtensionPatterns (
@@ -2285,6 +2370,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns(
22852370 patterns.getContext ());
22862371}
22872372
2373+ void mlir::vector::populateConvertToShapeCastPatterns (
2374+ RewritePatternSet &patterns, PatternBenefit benefit) {
2375+ patterns
2376+ .insert <TransposeToShapeCast, BroadcastToShapeCast, ExtractToShapeCast>(
2377+ patterns.getContext (), benefit);
2378+ }
2379+
22882380// ===----------------------------------------------------------------------===//
22892381// TableGen'd enum attribute definitions
22902382// ===----------------------------------------------------------------------===//
0 commit comments