@someoneinjd
Contributor

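Fold a linalg.broadcast producer into the consuming linalg.elementwise op by creating a new elementwise op whose indexing map for that operand is the broadcast's affine map.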

@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: someoneinjd (someoneinjd)

Changes

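Fold a linalg.broadcast producer into the consuming linalg.elementwise op by creating a new elementwise op whose indexing map for that operand is the broadcast's affine map.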


Full diff: https://github.com/llvm/llvm-project/pull/167626.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp (+36)
  • (modified) mlir/test/Dialect/Linalg/elementwise/fold.mlir (+44)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
index b1c0c3b161b20..0acebabafa594 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
@@ -64,6 +64,41 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
   }
 };
 
+struct FoldBroadcastPattern : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    SmallVector<Value> newIns;
+    SmallVector<AffineMap> newMaps;
+    for (OpOperand *operand : op.getDpsInputOperands()) {
+      AffineMap map = op.getMatchingIndexingMap(operand);
+      auto broadcastOp = operand->get().getDefiningOp<BroadcastOp>();
+
+      if (!map.isIdentity() || !broadcastOp) {
+        // push in original operand and its map.
+        newIns.push_back(operand->get());
+        newMaps.push_back(map);
+        continue;
+      }
+      newIns.push_back(broadcastOp.getInput());
+      // push in broadcastOp's broadcast map.
+      newMaps.push_back(broadcastOp.getMatchingIndexingMap(
+          broadcastOp.getDpsInputOperand(0)));
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    newMaps.push_back(op.getIndexingMapsArray().back());
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(newMaps));
+    return success();
+  }
+};
+
 struct LinalgFoldIntoElementwisePass
     : public impl::LinalgFoldIntoElementwisePassBase<
           LinalgFoldIntoElementwisePass> {
@@ -84,4 +119,5 @@ struct LinalgFoldIntoElementwisePass
 void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldTransposePattern>(patterns.getContext());
+  patterns.add<FoldBroadcastPattern>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
index e83c32fb6a2cf..a967629e2cf29 100644
--- a/mlir/test/Dialect/Linalg/elementwise/fold.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -41,3 +41,47 @@ func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tens
                           outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
   return %result : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//
+// CHECK:  func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:       indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME:       ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT:    return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_broadcasted(%A : tensor<8x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %empty = tensor.empty() : tensor<8x16x32xf32>
+  %broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty :  tensor<8x16x32xf32>) dimensions = [1]
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                          ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK:  func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT:  %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:              indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME:              ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT:  return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_broadcasted(%A : tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) ->  tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+  %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+  %broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty :  tensor<?x?xf32>) dimensions = [1] 
+  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+                          ins(%A, %broadcasted_B : tensor<?x?xf32>,  tensor<?x?xf32>)
+                          outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
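
The BROADCASTED maps expected by the tests above follow from the `dimensions` attribute of linalg.broadcast: the input is indexed by every output dimension that the broadcast did not add. A minimal standalone sketch of that derivation, assuming a map can be represented as a plain list of dimension indices (`broadcastInputDims` is a hypothetical helper, not an MLIR API):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper, not part of MLIR: returns the iteration dimensions that
// index the broadcast input, i.e. every output dimension that is NOT listed in
// `addedDims` (the op's `dimensions` attribute), in increasing order.
std::vector<int> broadcastInputDims(int outputRank,
                                    const std::vector<int> &addedDims) {
  // A broadcast cannot add more dimensions than the output has.
  assert(outputRank >= static_cast<int>(addedDims.size()));
  std::vector<int> results;
  for (int dim = 0; dim < outputRank; ++dim) {
    bool isAdded =
        std::find(addedDims.begin(), addedDims.end(), dim) != addedDims.end();
    if (!isAdded)
      results.push_back(dim);
  }
  return results;
}
```

For rank 3 and `dimensions = [1]` this yields `{0, 2}`, which reads as `(d0, d1, d2) -> (d0, d2)` in @unary_broadcasted; for rank 2 and `dimensions = [1]` it yields `{0}`, matching `(d0, d1) -> (d0)` in @binary_broadcasted.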

AffineMap map = op.getMatchingIndexingMap(operand);
auto broadcastOp = operand->get().getDefiningOp<BroadcastOp>();

if (!map.isIdentity() || !broadcastOp) {
Contributor

Hmm - also in case the existing map is not the identity map, shouldn't it still compose with the broadcast map? So we could obtain the new map via composition.

Consider the case of elementwise(transpose(broadcast)) whereupon we fold the transpose into the elementwise. This pattern then won't fold in the broadcast.
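
To illustrate the composition being suggested, here is a minimal standalone sketch. It assumes both maps are projected permutations and models them with a hypothetical `DimMap` struct rather than the real `AffineMap` class; the argument order follows `AffineMap::compose`, where the argument map is applied first.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a projected-permutation affine map: result i reads
// input dimension results[i]. This is not the MLIR AffineMap class.
struct DimMap {
  int numDims;
  std::vector<int> results;
};

// Returns outer(inner(x)): `inner` is applied first and its results feed the
// dimensions of `outer`, the same order as outer.compose(inner) in MLIR.
DimMap compose(const DimMap &outer, const DimMap &inner) {
  assert(outer.numDims == static_cast<int>(inner.results.size()));
  DimMap composed{inner.numDims, {}};
  for (int dim : outer.results)
    composed.results.push_back(inner.results[dim]);
  return composed;
}
```

With an existing elementwise map `(d0, d1, d2) -> (d1, d0, d2)` as `inner` and the broadcast map `(d0, d1, d2) -> (d0, d2)` as `outer`, the folded operand map comes out as `(d0, d1, d2) -> (d1, d2)`. When the existing map is the identity, the result is just the broadcast map, which is the case the current pattern handles.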

Contributor Author

Yeah, this only handles the identity map case.

I was modeling this new pattern directly after the existing FoldTransposePattern, which also seems to have this same limitation (it doesn't compose with non-identity maps).

I agree we should compose the maps as you suggested. Should I add that map composition logic here in this PR, or would you prefer a new, separate PR to add this functionality for both the new FoldBroadcastPattern and the existing FoldTransposePattern?

Member

This becomes an arbitrarily complex producer chain to analyse. We'd need a reasonable algebra here to simplify pairs and propagate transitive and associative properties, reversibility, etc.

I see two options:

  1. Simple: Two patterns (trans/bcast) that each apply only to the immediate producer and elide it. Repeatedly running this pass would take care of arbitrarily complex chains, but each pattern would have to compose the new map with the existing (non-ID) map.
  2. Complex: Multiple patterns that simplify chains of layout manipulation before merging with the op (elemwise/contract/generic), and then a simple apply to the argument directly (replace map with affine.apply map).

To me, the first option is much simpler, and iterative. It might not be as complete as the second, but we'll rarely see cases complex enough to need a catch-all mega pass.
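
A rough standalone sketch of how option 1 plays out on a chain, assuming every map is a projected permutation. `DimMap`, `compose`, and `foldProducerChain` are hypothetical names used only for illustration; the real patterns would be driven by the rewriter rather than an explicit loop.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a projected-permutation affine map: result i reads
// input dimension results[i]. This is not the MLIR AffineMap class.
struct DimMap {
  int numDims;
  std::vector<int> results;
};

// Returns outer(inner(x)); `inner` is applied first.
DimMap compose(const DimMap &outer, const DimMap &inner) {
  assert(outer.numDims == static_cast<int>(inner.results.size()));
  DimMap composed{inner.numDims, {}};
  for (int dim : outer.results)
    composed.results.push_back(inner.results[dim]);
  return composed;
}

// Folds producers one at a time, nearest producer first. Each step elides the
// immediate producer and composes its input map with the operand's current
// (possibly non-identity) map.
DimMap foldProducerChain(DimMap operandMap,
                         const std::vector<DimMap> &producerInputMaps) {
  for (const DimMap &producerMap : producerInputMaps)
    operandMap = compose(producerMap, operandMap);
  return operandMap;
}
```

For elementwise(transpose(broadcast(x))) with an identity operand map, a transpose that swaps the first two dimensions, and a broadcast with input map `(d0, d1, d2) -> (d0, d2)`, the first step gives `(d0, d1, d2) -> (d1, d0, d2)` and the second gives `(d0, d1, d2) -> (d1, d2)`.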

Contributor

Agree with @rengolin "This becomes an arbitrarily complex producer chain to analyse. We'd need a reasonable algebra here to simplify pairs and propagate transitive and associative properties, reversibility, etc."

transpose + broadcast + existing map should be seen as a projected permutation. Let me know if you want to go down that path. I did something similar here: https://github.com/llvm/llvm-project/blame/main/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp.
@MaheshRavishankar had pointed me in this direction originally.
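
For readers unfamiliar with the term, a small standalone sketch of what "projected permutation" means here: every result is a distinct input dimension, with some dimensions possibly dropped. It mirrors the idea behind `AffineMap::isProjectedPermutation` but is a hypothetical helper over plain dimension indices, not the MLIR implementation.

```cpp
#include <cassert>
#include <vector>

// Hypothetical check, not the MLIR API: `results[i]` is the input dimension
// read by result i. The map is a projected permutation when every result is an
// in-range dimension and no dimension appears twice.
bool isProjectedPermutation(int numDims, const std::vector<int> &results) {
  if (numDims < 0)
    return false;
  std::vector<bool> seen(numDims, false);
  for (int dim : results) {
    if (dim < 0 || dim >= numDims || seen[dim])
      return false;
    seen[dim] = true;
  }
  return true;
}
```

Transposes permute the dimensions and broadcasts drop some of them, so both satisfy this check, and so does any composition of them; `(d0, d1, d2) -> (d1, d2)` passes while `(d0, d1, d2) -> (d0, d0)` does not.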

@someoneinjd someoneinjd force-pushed the fold-broadcast-into-elementwise branch from fd9b2eb to f0d0ceb Compare November 13, 2025 01:09
@rengolin rengolin requested a review from javedabsar1 November 13, 2025 12:50
@someoneinjd someoneinjd force-pushed the fold-broadcast-into-elementwise branch from f0d0ceb to 788adaa Compare November 17, 2025 00:03
@someoneinjd
Contributor Author

Thank you all for the feedback.

I have updated the PR based on your suggestions. The pattern now handles the case where the elementwise op has an existing non-identity map (like a transpose).

The implementation now uses AffineMap::compose to combine the elementwise op's existing indexing map with the map from the broadcast / transpose op.

I've also added some new test cases that cover this composition logic, similar to the elementwise(transpose(broadcast(x))) example that was discussed.

Could you please take another look when you have a chance? I'd appreciate knowing if this implementation is on the right track.

@someoneinjd someoneinjd force-pushed the fold-broadcast-into-elementwise branch from 788adaa to 6e38277 Compare November 17, 2025 00:15