From 2ddaa8ecb824e263730b1747b2762c2526b0773e Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 1 Apr 2025 14:25:16 -0700 Subject: [PATCH 1/4] add canonicalizer --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 23 ++++++++++++++++++++-- mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 98d98f067de14..68b4c26880141 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5670,6 +5670,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern { } }; +// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp. +class ShapeCastPoisonFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + + if (!shapeCastOp.getSource().getDefiningOp()) + return failure(); + + rewriter.replaceOpWithNewOp(shapeCastOp, + shapeCastOp.getType()); + return success(); + } +}; + /// Helper function that computes a new vector type based on the input vector /// type by removing the trailing one dims: /// @@ -5828,8 +5845,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index b7db8ec834be7..72064fb42741a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { // ----- +// CHECK-LABEL: shape_cast_poison +// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32> +// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32> +// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32> +func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) { + %poison = ub.poison : vector<5x4x2xf32> + %poison_1 = ub.poison : vector<12x2xi32> + %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32> + %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32> + return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32> +} + +// ----- + // CHECK-LABEL: extract_strided_constant // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32> // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32> From 99ed6355272915d893db4a38fdfd7c3ba0fa0daa Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 7 Apr 2025 09:04:44 -0700 Subject: [PATCH 2/4] use folders where possible (replace 2 canons) Signed-off-by: James Newling --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 56 +++++++----------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 68b4c26880141..53fc47a6d6ef5 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5643,49 +5643,24 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return bcastOp.getSource(); } - return {}; -} - -namespace { -// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. -class ShapeCastConstantFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { - auto constantOp = - shapeCastOp.getSource().getDefiningOp(); - if (!constantOp) - return failure(); - // Only handle splat for now. - auto dense = llvm::dyn_cast(constantOp.getValue()); - if (!dense) - return failure(); - auto newAttr = - DenseElementsAttr::get(llvm::cast(shapeCastOp.getType()), - dense.getSplatValue()); - rewriter.replaceOpWithNewOp(shapeCastOp, newAttr); - return success(); + // Replace shape_cast(arith.constant) with arith.constant. Currently only + // handles splat constants. + if (auto constantOp = getSource().getDefiningOp()) { + if (auto dense = llvm::dyn_cast(constantOp.getValue())) { + return DenseElementsAttr::get(cast(getType()), + dense.getSplatValue()); + } } -}; -// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp. -class ShapeCastPoisonFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { + // Replace shape_cast(poison) with poison. + if (getSource().getDefiningOp()) { + return ub::PoisonAttr::get(getContext()); + } - if (!shapeCastOp.getSource().getDefiningOp()) - return failure(); + return {}; +} - rewriter.replaceOpWithNewOp(shapeCastOp, - shapeCastOp.getType()); - return success(); - } -}; +namespace { /// Helper function that computes a new vector type based on the input vector /// type by removing the trailing one dims: @@ -5846,8 +5821,7 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results - .add( + .add( context); } From eb5b9d7046b8c7ae0b30f8872ac4a593074af890 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 7 Apr 2025 10:16:20 -0700 Subject: [PATCH 3/4] simplify folding of const/poison Signed-off-by: James Newling --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 38 +++++++++++------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 53fc47a6d6ef5..0ac22969629da 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -42,6 +42,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include #include @@ -5611,28 +5612,27 @@ LogicalResult ShapeCastOp::verify() { } OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { + // No-op shape cast. - if (getSource().getType() == getResult().getType()) + if (getSource().getType() == getType()) return getSource(); // Canceling shape casts. if (auto otherOp = getSource().getDefiningOp()) { - if (getResult().getType() == otherOp.getSource().getType()) - return otherOp.getSource(); - // Only allows valid transitive folding. - VectorType srcType = llvm::cast(otherOp.getSource().getType()); - VectorType resultType = llvm::cast(getResult().getType()); - if (srcType.getRank() < resultType.getRank()) { - if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) + // Only allows valid transitive folding (expand/collapse dimensions). + VectorType srcType = otherOp.getSource().getType(); + if (getType() == srcType) + return otherOp.getSource(); + if (srcType.getRank() < getType().getRank()) { + if (!isValidShapeCast(srcType.getShape(), getType().getShape())) return {}; - } else if (srcType.getRank() > resultType.getRank()) { - if (!isValidShapeCast(resultType.getShape(), srcType.getShape())) + } else if (srcType.getRank() > getType().getRank()) { + if (!isValidShapeCast(getType().getShape(), srcType.getShape())) return {}; } else { return {}; } - setOperand(otherOp.getSource()); return getResult(); } @@ -5643,17 +5643,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return bcastOp.getSource(); } - // Replace shape_cast(arith.constant) with arith.constant. Currently only - // handles splat constants. - if (auto constantOp = getSource().getDefiningOp()) { - if (auto dense = llvm::dyn_cast(constantOp.getValue())) { - return DenseElementsAttr::get(cast(getType()), - dense.getSplatValue()); - } + // shape_cast(constant) -> constant + if (auto splatAttr = + llvm::dyn_cast_if_present(adaptor.getSource())) { + return DenseElementsAttr::get(getType(), + splatAttr.getSplatValue()); } - // Replace shape_cast(poison) with poison. - if (getSource().getDefiningOp()) { + // shape_cast(poison) -> poison + if (llvm::dyn_cast_if_present(adaptor.getSource())) { return ub::PoisonAttr::get(getContext()); } From d47886c4c6516b877baeacade6a29b80403aad9b Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 9 Apr 2025 13:06:13 -0700 Subject: [PATCH 4/4] use resultType Signed-off-by: James Newling --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 0ac22969629da..59f3b788cebed 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5617,18 +5617,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { if (getSource().getType() == getType()) return getSource(); + VectorType resultType = getType(); + // Canceling shape casts. if (auto otherOp = getSource().getDefiningOp()) { // Only allows valid transitive folding (expand/collapse dimensions). VectorType srcType = otherOp.getSource().getType(); - if (getType() == srcType) + if (resultType == srcType) return otherOp.getSource(); - if (srcType.getRank() < getType().getRank()) { - if (!isValidShapeCast(srcType.getShape(), getType().getShape())) + if (srcType.getRank() < resultType.getRank()) { + if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) return {}; - } else if (srcType.getRank() > getType().getRank()) { - if (!isValidShapeCast(getType().getShape(), srcType.getShape())) + } else if (srcType.getRank() > resultType.getRank()) { + if (!isValidShapeCast(resultType.getShape(), srcType.getShape())) return {}; } else { return {}; @@ -5639,14 +5641,14 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { // Cancelling broadcast and shape cast ops. if (auto bcastOp = getSource().getDefiningOp()) { - if (bcastOp.getSourceType() == getType()) + if (bcastOp.getSourceType() == resultType) return bcastOp.getSource(); } // shape_cast(constant) -> constant if (auto splatAttr = llvm::dyn_cast_if_present(adaptor.getSource())) { - return DenseElementsAttr::get(getType(), + return DenseElementsAttr::get(resultType, splatAttr.getSplatValue()); }