[mlir][vector] Canonicalize broadcast of shape_cast #150523


Merged: 10 commits merged into llvm:main on Aug 8, 2025.

Conversation

mshockwave (Member):

Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of `x` is compatible with broadcast's result type.
llvmbot (Member) commented Jul 24, 2025:

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)

Changes

Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of `x` is compatible with broadcast's result type.


Full diff: https://github.com/llvm/llvm-project/pull/150523.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+23-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+22)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto srcShapeCast =
+            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+      VectorType srcType = srcShapeCast.getSourceVectorType();
+      VectorType destType = broadcastOp.getResultVectorType();
+      if (vector::isBroadcastableTo(srcType, destType) ==
+          BroadcastableToResult::Success) {
+        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                                 srcShapeCast.getSource());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
   // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+  return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+//       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+//       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+  %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+  return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index

@banach-space (Contributor) left a comment:
Thanks, makes sense % minor suggestions.

Comment on lines 2949 to 2960:

    if (auto srcShapeCast =
            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
      VectorType srcType = srcShapeCast.getSourceVectorType();
      VectorType destType = broadcastOp.getResultVectorType();
      if (vector::isBroadcastableTo(srcType, destType) ==
          BroadcastableToResult::Success) {
        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                                 srcShapeCast.getSource());
        return success();
      }
    }
    return failure();
banach-space (Contributor):
[nit] Prefer early exits - helps reduce indentation.

Suggested change -- from:

    if (auto srcShapeCast =
            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
      VectorType srcType = srcShapeCast.getSourceVectorType();
      VectorType destType = broadcastOp.getResultVectorType();
      if (vector::isBroadcastableTo(srcType, destType) ==
          BroadcastableToResult::Success) {
        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                                 srcShapeCast.getSource());
        return success();
      }
    }
    return failure();

to:

    auto srcShapeCast =
        broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();
    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                             srcShapeCast.getSource());
    return success();

mshockwave (Member, Author):
Fixed.

newling (Contributor) commented Jul 25, 2025:

Can you reuse `isBroadcastLike` (`static bool isBroadcastLike(Operation *op)`)?

The general rule is that if something can be a folder, it should be (i.e., on `BroadcastOp::fold`): https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations

@newling (Contributor) left a comment:

I don't think this is always valid?

(2,1) -> shape_cast -> (1,2) -> broadcast (2,2)

and

(2,1) -> broadcast (2,2)

are different.

Example: if the input is [[5], [6]], then the first one's output is [[5, 6], [5, 6]] but the second one's is [[5, 5], [6, 6]].
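This divergence is easy to reproduce with a tiny pure-Python model of the two pipelines on nested lists (the helper names below are illustrative, not MLIR APIs):

```python
# Model the two pipelines from the counterexample above on nested lists.
# Illustrative sketch only; these helpers are not MLIR APIs.

def broadcast_2x1_to_2x2(col):
    # vector<2x1> -> vector<2x2>: the size-1 trailing dim is stretched,
    # so each row's single element is duplicated across its row.
    return [[row[0], row[0]] for row in col]

def shapecast_then_broadcast_to_2x2(col):
    # shape_cast vector<2x1> -> vector<1x2>, then broadcast to vector<2x2>:
    # now the size-1 leading dim is stretched, so the whole row is
    # duplicated downward.
    row = [r[0] for r in col]          # [[5], [6]] -> [5, 6]
    return [list(row), list(row)]

x = [[5], [6]]
print(broadcast_2x1_to_2x2(x))             # [[5, 5], [6, 6]]
print(shapecast_then_broadcast_to_2x2(x))  # [[5, 6], [5, 6]]
```

The two results differ, so the fold is not valid for this operand shape.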

mshockwave (Member, Author):

> I don't think this is always valid? […]

Yeah, you're right. Let me turn this PR into a draft and think about this.

@mshockwave mshockwave marked this pull request as draft July 25, 2025 21:38
Comment on lines 2942 to 2963:

    // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
    // with broadcast's result type.
    struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
      using OpRewritePattern::OpRewritePattern;

      LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                    PatternRewriter &rewriter) const override {
        auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
        if (!srcShapeCast)
          return failure();

        VectorType srcType = srcShapeCast.getSourceVectorType();
        VectorType destType = broadcastOp.getResultVectorType();
        if (vector::isBroadcastableTo(srcType, destType) !=
            BroadcastableToResult::Success)
          return failure();

        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                                 srcShapeCast.getSource());
        return success();
      }
    };
A reviewer (Member):
This should be a folder, not a rewrite pattern.

mshockwave (Member, Author):

I just rewrote it into a folder.

@mshockwave mshockwave marked this pull request as ready for review August 6, 2025 00:19
@mshockwave mshockwave requested review from Groverkss and newling August 6, 2025 00:19
mshockwave (Member, Author):

> I don't think this is always valid? […]

I updated the algorithm to add a condition that the replicated dimensions have to be the same before and after the transformation.

@newling (Contributor) left a comment:

Thanks! I think this is correct now, but I added a suggestion which might simplify it.

    // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
    // with broadcast's result type and the broadcasted dimensions are the same.
    static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
      auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
newling (Contributor):

I think this is the same as saying (where srcShape -> shapeCastShape -> destShape):

  1. rank(srcShape) <= rank(destShape)
  2. srcShape and shapeCastShape are the same, except that one has some 1's prepended; i.e., where R = min(srcShape.rank, shapeCastShape.rank), the last R dimensions of srcShape and shapeCastShape are the same.

If so, it would be more intuitive, I think. If not, can you please provide a counterexample?
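As a sanity check, this suggested rule can be sketched on plain shape tuples (a hypothetical helper for illustration, not the MLIR implementation):

```python
def fold_is_valid(src_shape, cast_shape, dest_shape):
    # Hypothetical sketch of the suggested rule, on plain shape tuples.
    # 1. rank(srcShape) <= rank(destShape)
    if len(src_shape) > len(dest_shape):
        return False
    # 2. srcShape and shapeCastShape agree on their last R dims, and any
    #    extra leading dims on either side are all 1's.
    r = min(len(src_shape), len(cast_shape))
    if src_shape[-r:] != cast_shape[-r:]:
        return False
    return all(d == 1 for d in src_shape[:-r] + cast_shape[:-r])

# Valid: <2> -> shape_cast -> <1x2> -> broadcast -> <32x2>
print(fold_is_valid((2,), (1, 2), (32, 2)))       # True
# Invalid: the <2x1> -> <1x2> counterexample from this thread
print(fold_is_valid((2, 1), (1, 2), (2, 2)))      # False
# Invalid: <64> -> <4x16> -> <2x4x16> (non-unit dims redistributed)
print(fold_is_valid((64,), (4, 16), (2, 4, 16)))  # False
```

The three checks correspond to the valid test case, the counterexample, and the invalid-shape test case discussed in this PR.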

mshockwave (Member, Author) commented Aug 7, 2025:

We can roughly break this down into five cases by how we shape_cast:
(1) srcShape is "broken" up into multiple non-one dimensions, e.g. <4x1> -> <2x2>
(2) srcShape is prepended with one or more ones
(3) srcShape is appended with one or more ones
(4) One or more leading dimensions in srcShape were removed
(5) One or more trailing dimensions in srcShape were removed

Note that multiple cases could apply at the same time. For instance, <2x1> -> <1x2> removes the trailing dimension before prepending a new one.

Case (1) is easy: srcShape will never be broadcastable w.r.t. destShape, because the broadcast rule effectively mandates that the source dimensions be a "subset" of the destination dimensions, modulo dimensions that are one, and changing the dimension values violates that.

I think cases (2) and (4) are conjugate: broadcasting at those prepended dimensions that are one is the same as broadcasting toward missing (leading) dimensions; similarly, broadcasting at missing leading dimensions is the same as broadcasting at ones that were once there. Therefore, they are allowed.

Cases (3) and (5) are similar: both change the "neighboring" elements in the highest dimension -- an element either becomes or stops being a 'singleton'. For instance, [A, B] turns into [[A], [B]] when we cast from <2> to <2x1>, in which case element A turns from having a neighbor B into a singleton. Whether an element is a singleton matters, because an element that is not a singleton will always be broadcast together with its neighbor, while a singleton can be replicated on its own. Since this alters the broadcasting behavior, once it appears -- even combined with other cases, like the <2x1> -> <1x2> mentioned earlier -- we cannot do the folding. Note that this also coincides with my current rule -- the original replicated dimensions have to match the new replicated dimensions.

The bottom line: I think your new rule is correct; I'm going to update to it.
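For intuition, cases (2) and (4) can be sanity-checked with a small pure-Python broadcast model (illustrative only; `broadcast` here mimics vector.broadcast on nested lists, including treating missing leading dims as prepended 1's):

```python
# Pure-Python broadcast model; illustrative, not the MLIR implementation.

def shape(v):
    # Infer the shape of a (rectangular) nested list.
    s = []
    while isinstance(v, list):
        s.append(len(v))
        v = v[0]
    return s

def broadcast(v, target):
    s = shape(v)
    # Missing leading dims behave like prepended 1's (cases (2)/(4)).
    for _ in range(len(target) - len(s)):
        v, s = [v], [1] + s
    def stretch(v, s, t):
        if not t:
            return v
        if s[0] == 1 and t[0] > 1:
            v = v * t[0]  # replicate a size-1 dim (shared refs, read-only)
        return [stretch(x, s[1:], t[1:]) for x in v]
    return stretch(v, s, target)

# Case (4): dropping the leading 1 before broadcasting changes nothing.
print(broadcast([5, 6], [3, 2]) == broadcast([[5, 6]], [3, 2]))  # True
# The counterexample still behaves differently, as expected:
print(broadcast([[5], [6]], [2, 2]))  # [[5, 5], [6, 6]]
```

Under this model, prepending or dropping leading unit dims never changes the broadcast result, matching the analysis of cases (2) and (4).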

mshockwave (Member, Author):
The algorithm is now updated.

newling (Contributor):
Thanks for the analysis; looks good to me, as does the new impl. I think rank(srcShape) <= rank(destShape) is sufficient, but the way you check with isBroadcastableTo will probably be more intuitive to future readers.

Comment on lines 1174 to 1178:

    func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
      %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
      %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
      return %1 : vector<32x2xf32>
    }
newling (Contributor):
What about dim-1 broadcasting?

mshockwave (Member, Author):
I added more test cases for that.

@newling (Contributor) left a comment:
LGTM. Thanks for the comprehensive testing.

@mshockwave mshockwave merged commit b4e8b8e into llvm:main Aug 8, 2025
9 checks passed
@mshockwave mshockwave deleted the patch/mlir/shapecast-broadcast branch August 8, 2025 16:25
Comment on lines +2869 to +2873:

    assert(all_of(srcShape.drop_back(numTrailingDims),
                  [](int64_t E) { return E == 1; }) &&
               all_of(shapecastShape.drop_back(numTrailingDims),
                      [](int64_t E) { return E == 1; }) &&
           "ill-formed shape_cast");
A reviewer (Contributor):

[nit] Unlike LLVM, we use camelCase in MLIR for variable names. So, E -> e (rather confusing, I know). If you want to avoid e (less readable than E, IMHO), you could try E -> dim 🤷🏻

5 participants