Conversation

@CoTinker (Contributor)

Back-to-back `linalg.broadcast` ops can be rewritten as a single broadcast.
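
For illustration, this mirrors the first test added in this PR: two stacked broadcasts such as

    %broadcast1 = linalg.broadcast
        ins(%input : tensor<2xf32>)
        outs(%init1 : tensor<2x3xf32>)
        dimensions = [1]
    %broadcast2 = linalg.broadcast
        ins(%broadcast1 : tensor<2x3xf32>)
        outs(%init2 : tensor<2x3x4xf32>)
        dimensions = [2]

canonicalize to a single op:

    %broadcast = linalg.broadcast
        ins(%input : tensor<2xf32>)
        outs(%init2 : tensor<2x3x4xf32>)
        dimensions = [1, 2]

The rewrite is registered as a canonicalization pattern, so it can be exercised with `mlir-opt --canonicalize`.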

@llvmbot (Member) commented Jul 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

Changes

Back-to-back `linalg.broadcast` ops can be rewritten as a single broadcast.


Full diff: https://github.com/llvm/llvm-project/pull/150825.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+32-1)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+46)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b661781f10f..c1544d571a938 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2292,9 +2292,40 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+/// Fold broadcast with broadcast.
+struct FoldBroadcastWithBroadcast : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+    if (!defBroadcastOp)
+      return failure();
+    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    SmallVector<int64_t> foldedDims(dimensions);
+    Value init = broadcastOp.getInit();
+    int64_t initRank = init.getType().getRank();
+    // Mapping from input dims to init dims.
+    SmallVector<int64_t> dimMap;
+    for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+      if (!llvm::is_contained(dimensions, dim))
+        dimMap.push_back(dim);
+    }
+    for (auto dim : defDimensions)
+      foldedDims.push_back(dimMap[dim]);
+
+    llvm::sort(foldedDims);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcastWithBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cbb56e4de884..0e8546404f9e2 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) permutation = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x3xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x3xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x3xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [2]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+//  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+//       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) permutation = [1, 2]
+//   CHECK-NOT:   linalg.broadcast
+//       CHECK:   return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x4xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x4xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x4xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [1]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
   %transpose = linalg.transpose

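For reference, the heart of the new pattern is the dimension remapping: `dimensions` lists the dims the outer broadcast adds to the final init, so every other init dim corresponds, in order, to a dim of the outer broadcast's input (`dimMap`). Each dim added by the producer broadcast (`defDimensions`) is translated through `dimMap`, appended, and the combined list is sorted. In the second test above, `dimensions = [1]` and `initRank = 3` give `dimMap = [0, 2]`; `defDimensions = [1]` maps to `2`, so the folded broadcast uses `[1, 2]`. Below is a minimal standalone sketch of the same arithmetic in plain C++ (not the MLIR pattern itself; names are illustrative):

    // Standalone sketch of the dimension folding performed by
    // FoldBroadcastWithBroadcast, using STL containers instead of MLIR/LLVM types.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // dims:     dims added by the outer broadcast (indices into the final init).
    // defDims:  dims added by the inner broadcast (indices into the inner init,
    //           i.e. the outer broadcast's input).
    // initRank: rank of the final init.
    std::vector<int64_t> foldBroadcastDims(const std::vector<int64_t> &dims,
                                           const std::vector<int64_t> &defDims,
                                           int64_t initRank) {
      std::vector<int64_t> folded(dims);
      // dimMap[i] = final-init dim that the i-th dim of the outer input maps to,
      // i.e. every final-init dim the outer broadcast did not introduce.
      std::vector<int64_t> dimMap;
      for (int64_t d = 0; d < initRank; ++d)
        if (std::find(dims.begin(), dims.end(), d) == dims.end())
          dimMap.push_back(d);
      // Remap the inner broadcast's dims into the final init's index space.
      for (int64_t d : defDims)
        folded.push_back(dimMap[d]);
      std::sort(folded.begin(), folded.end());
      return folded;
    }

    int main() {
      // Second test case: 2 -> 2x4 (dimensions = [1]) -> 2x3x4 (dimensions = [1]).
      for (int64_t d : foldBroadcastDims({1}, {1}, 3))
        std::cout << d << ' '; // prints: 1 2
      std::cout << '\n';
    }

Compiling and running this prints `1 2`, matching the `[1, 2]` expected by the tests above.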
@CoTinker force-pushed the fold_brc_with_brc branch from 11eb0e0 to 5b388fb on July 27, 2025 11:05
@CoTinker force-pushed the fold_brc_with_brc branch from 5b388fb to 23f53b5 on July 27, 2025 11:18
@rengolin (Member) left a comment:

Small nit, but lgtm, thanks!

@CoTinker merged commit b30034d into llvm:main on Jul 29, 2025
9 checks passed
@CoTinker deleted the fold_brc_with_brc branch on July 29, 2025 11:25