newling (Contributor) commented Apr 8, 2025

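Extend the ShapeCast(Broadcast) -> Broadcast canonicalization to cover more cases. Instead of checking that the broadcast source shape is a suffix of the result shape, the pattern now calls the 'source of truth' function isBroadcastableTo to decide whether the source can be broadcast directly to the shape produced by the shape_cast.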
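
For readers unfamiliar with the rule that isBroadcastableTo encodes, here is a minimal standalone sketch. It is not the MLIR implementation: it assumes fixed-size (non-scalable) dimensions only, and the names `Shape`, `canBroadcastShape`, `isSuffixOf` and `checkShapesFromThisPr` are hypothetical. Under that assumption, a source shape can be broadcast to a destination shape when it has no more dimensions than the destination and, with both shapes aligned at their trailing dimensions, every source dimension is either 1 or equal to the matching destination dimension.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper, not MLIR's isBroadcastableTo: handles fixed-size
// dimensions only and ignores element types and scalable dimensions.
bool canBroadcastShape(const Shape &src, const Shape &dst) {
  if (src.size() > dst.size())
    return false;
  // Align the shapes at their trailing dimensions.
  const std::size_t lead = dst.size() - src.size();
  for (std::size_t i = 0; i < src.size(); ++i)
    if (src[i] != 1 && src[i] != dst[lead + i])
      return false;
  return true;
}

// The stricter check used before this PR: `src` must equal the trailing
// dimensions of `dst`.
bool isSuffixOf(const Shape &src, const Shape &dst) {
  if (src.size() > dst.size())
    return false;
  return std::equal(src.begin(), src.end(),
                    dst.end() - static_cast<std::ptrdiff_t>(src.size()));
}

// Shapes taken from the tests and examples in this PR.
void checkShapesFromThisPr() {
  // vector<1x1xi8> to vector<1x1x6x1x4xi8>: not a suffix, but broadcastable.
  assert(!isSuffixOf({1, 1}, {1, 1, 6, 1, 4}));
  assert(canBroadcastShape({1, 1}, {1, 1, 6, 1, 4}));
  // vector<3x4xf32> to vector<12xf32>: neither check passes, so only the
  // element-count case of the pattern can apply.
  assert(!isSuffixOf({3, 4}, {12}));
  assert(!canBroadcastShape({3, 4}, {12}));
}
```

Every shape accepted by the suffix check is also accepted by the broadcast rule, so switching checks only adds cases such as the unit-dimension source above.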

llvmbot (Member) commented Apr 8, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Full diff: https://github.com/llvm/llvm-project/pull/134939.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+14-19)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+25)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..c6d8ec1e1cf69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5778,8 +5778,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
 
 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
 /// This only applies when the shape of the broadcast source
-/// 1. is a suffix of the shape of the result (i.e. when broadcast without
-///    reshape is expressive enough to capture the result in a single op), or
+/// 1. can be broadcast directly to the final shape, or
 /// 2. has the same element count as the shape cast result.
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
@@ -5792,24 +5791,20 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     if (!broadcastOp)
       return failure();
 
-    ArrayRef<int64_t> broadcastSourceShape;
-    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
-      broadcastSourceShape = srcType.getShape();
-    ArrayRef<int64_t> shapeCastTargetShape =
-        shapeCastOp.getResultVectorType().getShape();
-
-    // If `broadcastSourceShape` is a suffix of the result, we can just replace
-    // with a broadcast to the final shape.
-    if (broadcastSourceShape ==
-        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          shapeCastOp, shapeCastOp.getResultVectorType(),
-          broadcastOp.getSource());
-      return success();
+    {
+      VectorType dstType = shapeCastOp.getResultVectorType();
+      auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+      bool isScalar = !srcType;
+      if (isScalar || isBroadcastableTo(srcType, dstType) ==
+                          BroadcastableToResult::Success) {
+        rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+            shapeCastOp, dstType, broadcastOp.getSource());
+        return success();
+      }
     }
 
-    // Otherwise, if the final result has the same element count, we can replace
-    // with a shape cast.
+    // If the final result has the same element count, we can replace with a
+    // shape cast.
     if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
       if (srcType.getNumElements() ==
           shapeCastOp.getResultVectorType().getNumElements()) {
@@ -6079,7 +6074,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
-// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
+// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
 struct FoldTransposedScalarBroadcast final
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..d7617d79b5cbf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1017,6 +1017,31 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
 
 // -----
 
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
+//       CHECK:   vector.broadcast
+//  CHECK-SAME:   f32 to vector<3x4x1xf32>
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<12xf32>
+  %1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
+  return %1 : vector<3x4x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_ones
+//       CHECK:   vector.broadcast
+//  CHECK-SAME:   vector<1x1xi8> to vector<1x1x6x1x4xi8>
+//   CHECK-NOT:   vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
+  %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
+  return %1 : vector<1x1x6x1x4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
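
Read as a decision procedure, the updated pattern in the diff above tries the direct broadcast first and falls back to the element-count check. The following is a minimal sketch of that ordering only. The names `Rewrite`, `chooseRewrite` and `checkExamples` are hypothetical, and the boolean `srcBroadcastableToResult` stands in for the result of the real isBroadcastableTo call rather than reimplementing it.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical names for illustration; none of these exist in MLIR.
enum class Rewrite { ToBroadcast, ToShapeCast, NoMatch };

// Mirrors the order of checks shown in the diff for shape_cast(broadcast(x)):
//  - srcIsScalar: x is not a vector.
//  - srcBroadcastableToResult: stand-in for
//    isBroadcastableTo(srcType, dstType) == BroadcastableToResult::Success.
//  - srcNumElements / resultNumElements: element counts of x and of the
//    shape_cast result.
Rewrite chooseRewrite(bool srcIsScalar, bool srcBroadcastableToResult,
                      int64_t srcNumElements, int64_t resultNumElements) {
  if (srcIsScalar || srcBroadcastableToResult)
    return Rewrite::ToBroadcast;
  if (srcNumElements == resultNumElements)
    return Rewrite::ToShapeCast;
  return Rewrite::NoMatch;
}

void checkExamples() {
  // f32 -> vector<12xf32> -> vector<3x4x1xf32>: scalar source.
  assert(chooseRewrite(true, false, 1, 12) == Rewrite::ToBroadcast);
  // vector<1x1xi8> -> vector<6x4xi8> -> vector<1x1x6x1x4xi8>.
  assert(chooseRewrite(false, true, 1, 24) == Rewrite::ToBroadcast);
  // vector<3x4xf32> -> vector<1x3x4xf32> -> vector<12xf32>.
  assert(chooseRewrite(false, false, 12, 12) == Rewrite::ToShapeCast);
  // Made-up case: vector<4xf32> -> vector<3x4xf32> -> vector<12xf32>.
  assert(chooseRewrite(false, false, 4, 12) == Rewrite::NoMatch);
}
```

The first three cases use shapes that appear in this PR; the last one is an invented example of an input that neither rewrite covers.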

banach-space (Contributor) left a comment

Thanks!

dcaballe (Contributor) left a comment

Nice!

@newling newling force-pushed the vector_broadcast_canonicalization branch from d5d59c2 to b58b837 Compare April 9, 2025 20:42

newling (Contributor, Author) commented Apr 9, 2025

Thanks for your suggested improvements, @dcaballe and @banach-space. I've hopefully addressed them all.

banach-space (Contributor) left a comment

Great clean-up, thank you!

I've left one [nit], but that's non-blocking so approving as is. LGTM

@newling newling force-pushed the vector_broadcast_canonicalization branch from 151bbfd to 408842c Compare April 10, 2025 16:22

dcaballe (Contributor) left a comment

Cool, thanks!

Comment on lines +5794 to 5807
    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
    // Example:
    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
    // to
    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),

Member

It's ok to land for now, but this should be a folder, not a canonicalization pattern.

@Groverkss Groverkss merged commit 409def2 into llvm:main Apr 11, 2025
11 checks passed