@@ -5768,11 +5768,10 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
5768
5768
}
5769
5769
};
5770
5770
5771
- // / Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5772
- // / This only applies when the shape of the broadcast source
5773
- // / 1. is a suffix of the shape of the result (i.e. when broadcast without
5774
- // / reshape is expressive enough to capture the result in a single op), or
5775
- // / 2. has the same element count as the shape cast result.
5771
+ /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5772
+ /// i) Y = ShapeCast(X), when X is a vector with as many elements as Y, or
5773
+ /// ii) Y = Broadcast(X), when X is a scalar or is broadcastable to Y's type.
5774
+ /// If both (i) and (ii) are possible, (i) is chosen.
5776
5775
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5777
5776
public:
5778
5777
using OpRewritePattern::OpRewritePattern;
@@ -5784,26 +5783,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5784
5783
if (!broadcastOp)
5785
5784
return failure ();
5786
5785
5787
- ArrayRef<int64_t > broadcastSourceShape;
5788
- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ()))
5789
- broadcastSourceShape = srcType.getShape ();
5790
- ArrayRef<int64_t > shapeCastTargetShape =
5791
- shapeCastOp.getResultVectorType ().getShape ();
5792
-
5793
- // If `broadcastSourceShape` is a suffix of the result, we can just replace
5794
- // with a broadcast to the final shape.
5795
- if (broadcastSourceShape ==
5796
- shapeCastTargetShape.take_back (broadcastSourceShape.size ())) {
5797
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
5798
- shapeCastOp, shapeCastOp.getResultVectorType (),
5799
- broadcastOp.getSource ());
5800
- return success ();
5801
- }
5802
-
5803
- // Otherwise, if the final result has the same element count, we can replace
5804
- // with a shape cast.
5805
- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ())) {
5806
- if (srcType.getNumElements () ==
5786
+ auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
5787
+ bool srcIsScalar = !srcVectorType;
5788
+
5789
+ // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5790
+ // Example:
5791
+ // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5792
+ // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5793
+ // to
5794
+ // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5795
+ if (srcVectorType) {
5796
+ if (srcVectorType.getNumElements () ==
5807
5797
shapeCastOp.getResultVectorType ().getNumElements ()) {
5808
5798
rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
5809
5799
shapeCastOp, shapeCastOp.getResultVectorType (),
@@ -5812,6 +5802,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5812
5802
}
5813
5803
}
5814
5804
5805
+ // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X).
5806
+ // Example:
5807
+ // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
5808
+ // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
5809
+ // to
5810
+ // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
5811
+ VectorType dstVectorType = shapeCastOp.getResultVectorType ();
5812
+ if (srcIsScalar || isBroadcastableTo (srcVectorType, dstVectorType) ==
5813
+ BroadcastableToResult::Success) {
5814
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
5815
+ shapeCastOp, dstVectorType, broadcastOp.getSource ());
5816
+ return success ();
5817
+ }
5815
5818
return failure ();
5816
5819
}
5817
5820
};
@@ -6072,7 +6075,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6072
6075
}
6073
6076
};
6074
6077
6075
- // Folds transpose(broadcast(<scalar>)) into brodcast (<scalar>).
6078
+ // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
6076
6079
struct FoldTransposedScalarBroadcast final
6077
6080
: public OpRewritePattern<vector::TransposeOp> {
6078
6081
using OpRewritePattern::OpRewritePattern;
0 commit comments