@@ -5778,8 +5778,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
57785778
57795779// / Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
57805780// / This only applies when the shape of the broadcast source
5781- // / 1. is a suffix of the shape of the result (i.e. when broadcast without
5782- // / reshape is expressive enough to capture the result in a single op), or
5781+ // / 1. can be broadcast directly to the final shape, or
57835782// / 2. has the same element count as the shape cast result.
57845783class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57855784public:
@@ -5792,24 +5791,20 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57925791 if (!broadcastOp)
57935792 return failure ();
57945793
5795- ArrayRef<int64_t > broadcastSourceShape;
5796- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ()))
5797- broadcastSourceShape = srcType.getShape ();
5798- ArrayRef<int64_t > shapeCastTargetShape =
5799- shapeCastOp.getResultVectorType ().getShape ();
5800-
5801- // If `broadcastSourceShape` is a suffix of the result, we can just replace
5802- // with a broadcast to the final shape.
5803- if (broadcastSourceShape ==
5804- shapeCastTargetShape.take_back (broadcastSourceShape.size ())) {
5805- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
5806- shapeCastOp, shapeCastOp.getResultVectorType (),
5807- broadcastOp.getSource ());
5808- return success ();
5794+ {
5795+ VectorType dstType = shapeCastOp.getResultVectorType ();
5796+ auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ());
5797+ bool isScalar = !srcType;
5798+ if (isScalar || isBroadcastableTo (srcType, dstType) ==
5799+ BroadcastableToResult::Success) {
5800+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
5801+ shapeCastOp, dstType, broadcastOp.getSource ());
5802+ return success ();
5803+ }
58095804 }
58105805
5811- // Otherwise, if the final result has the same element count, we can replace
5812- // with a shape cast.
5806+ // If the final result has the same element count, we can replace with a
5807+ // shape cast.
58135808 if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType ())) {
58145809 if (srcType.getNumElements () ==
58155810 shapeCastOp.getResultVectorType ().getNumElements ()) {
@@ -6079,7 +6074,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
60796074 }
60806075};
60816076
6082- // Folds transpose(broadcast(<scalar>)) into brodcast (<scalar>).
6077+ // Folds transpose(broadcast(<scalar>)) into broadcast (<scalar>).
60836078struct FoldTransposedScalarBroadcast final
60846079 : public OpRewritePattern<vector::TransposeOp> {
60856080 using OpRewritePattern::OpRewritePattern;
0 commit comments