diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 98d98f067de14..59f3b788cebed 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -42,6 +42,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include #include @@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() { } OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { + // No-op shape cast. - if (getSource().getType() == getResult().getType()) + if (getSource().getType() == getType()) return getSource(); + VectorType resultType = getType(); + // Canceling shape casts. if (auto otherOp = getSource().getDefiningOp()) { - if (getResult().getType() == otherOp.getSource().getType()) - return otherOp.getSource(); - // Only allows valid transitive folding. - VectorType srcType = llvm::cast(otherOp.getSource().getType()); - VectorType resultType = llvm::cast(getResult().getType()); + // Only allows valid transitive folding (expand/collapse dimensions). + VectorType srcType = otherOp.getSource().getType(); + if (resultType == srcType) + return otherOp.getSource(); if (srcType.getRank() < resultType.getRank()) { if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) return {}; @@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { } else { return {}; } - setOperand(otherOp.getSource()); return getResult(); } // Cancelling broadcast and shape cast ops. if (auto bcastOp = getSource().getDefiningOp()) { - if (bcastOp.getSourceType() == getType()) + if (bcastOp.getSourceType() == resultType) return bcastOp.getSource(); } + // shape_cast(constant) -> constant + if (auto splatAttr = + llvm::dyn_cast_if_present(adaptor.getSource())) { + return DenseElementsAttr::get(resultType, + splatAttr.getSplatValue()); + } + + // shape_cast(poison) -> poison + if (llvm::dyn_cast_if_present(adaptor.getSource())) { + return ub::PoisonAttr::get(getContext()); + } + return {}; } namespace { -// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. -class ShapeCastConstantFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { - auto constantOp = - shapeCastOp.getSource().getDefiningOp(); - if (!constantOp) - return failure(); - // Only handle splat for now. - auto dense = llvm::dyn_cast(constantOp.getValue()); - if (!dense) - return failure(); - auto newAttr = - DenseElementsAttr::get(llvm::cast(shapeCastOp.getType()), - dense.getSplatValue()); - rewriter.replaceOpWithNewOp(shapeCastOp, newAttr); - return success(); - } -}; /// Helper function that computes a new vector type based on the input vector /// type by removing the trailing one dims: @@ -5828,8 +5820,9 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index b7db8ec834be7..72064fb42741a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { // ----- +// CHECK-LABEL: shape_cast_poison +// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32> +// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32> +// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32> +func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) { + %poison = ub.poison : vector<5x4x2xf32> + %poison_1 = ub.poison : vector<12x2xi32> + %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32> + %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32> + return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32> +} + +// ----- + // CHECK-LABEL: extract_strided_constant // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32> // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>