@@ -5768,11 +5768,10 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
57685768 }
57695769};
57705770
5771- // / Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5772- // / This only applies when the shape of the broadcast source
5773- // / 1. is a suffix of the shape of the result (i.e. when broadcast without
5774- // / reshape is expressive enough to capture the result in a single op), or
5775- // / 2. has the same element count as the shape cast result.
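5771+ // / Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5772+ // / i) Y = ShapeCast(X), when X is a vector with as many elements as Y, or
5773+ // / ii) Y = Broadcast(X), when X is a scalar or is broadcastable to Y's type.
5774+ // / If both (i) and (ii) are possible, (i) is chosen.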
57765775class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57775776public:
57785777 using OpRewritePattern::OpRewritePattern;
@@ -5784,26 +5783,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57845783 if (!broadcastOp)
57855784 return failure ();
57865785
5787- ArrayRef<int64_t > broadcastSourceShape;
5788- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ()))
5789- broadcastSourceShape = srcType.getShape ();
5790- ArrayRef<int64_t > shapeCastTargetShape =
5791- shapeCastOp.getResultVectorType ().getShape ();
5792-
5793- // If `broadcastSourceShape` is a suffix of the result, we can just replace
5794- // with a broadcast to the final shape.
5795- if (broadcastSourceShape ==
5796- shapeCastTargetShape.take_back (broadcastSourceShape.size ())) {
5797- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
5798- shapeCastOp, shapeCastOp.getResultVectorType (),
5799- broadcastOp.getSource ());
5800- return success ();
5801- }
5802-
5803- // Otherwise, if the final result has the same element count, we can replace
5804- // with a shape cast.
5805- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ())) {
5806- if (srcType.getNumElements () ==
5786+ auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
5787+ bool srcIsScalar = !srcVectorType;
5788+
5789+ // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5790+ // Example:
5791+ // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5792+ // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5793+ // to
5794+ // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5795+ if (srcVectorType) {
5796+ if (srcVectorType.getNumElements () ==
58075797 shapeCastOp.getResultVectorType ().getNumElements ()) {
58085798 rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
58095799 shapeCastOp, shapeCastOp.getResultVectorType (),
@@ -5812,6 +5802,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58125802 }
58135803 }
58145804
5805+ // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X).
5806+ // Example:
5807+ // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
5808+ // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
5809+ // to
5810+ // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
5811+ VectorType dstVectorType = shapeCastOp.getResultVectorType ();
5812+ if (srcIsScalar || isBroadcastableTo (srcVectorType, dstVectorType) ==
5813+ BroadcastableToResult::Success) {
5814+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
5815+ shapeCastOp, dstVectorType, broadcastOp.getSource ());
5816+ return success ();
5817+ }
58155818 return failure ();
58165819 }
58175820};
@@ -6072,7 +6075,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
60726075 }
60736076};
60746077
6075- // Folds transpose(broadcast(<scalar>)) into brodcast (<scalar>).
6078+ // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
60766079struct FoldTransposedScalarBroadcast final
60776080 : public OpRewritePattern<vector::TransposeOp> {
60786081 using OpRewritePattern::OpRewritePattern;
0 commit comments