[mlir][linalg][elementwise] Fold broadcast into new elementwise #167626
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: someoneinjd (someoneinjd)

Changes

Fold broadcast into new elementwise Op which has affine-map attached.

Full diff: https://github.com/llvm/llvm-project/pull/167626.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
index b1c0c3b161b20..0acebabafa594 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
@@ -64,6 +64,41 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
}
};
+struct FoldBroadcastPattern : public OpRewritePattern<ElementwiseOp> {
+ using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ElementwiseOp op,
+ PatternRewriter &rewriter) const override {
+ bool changed = false;
+ SmallVector<Value> newIns;
+ SmallVector<AffineMap> newMaps;
+ for (OpOperand *operand : op.getDpsInputOperands()) {
+ AffineMap map = op.getMatchingIndexingMap(operand);
+ auto broadcastOp = operand->get().getDefiningOp<BroadcastOp>();
+
+ if (!map.isIdentity() || !broadcastOp) {
+ // push in original operand and its map.
+ newIns.push_back(operand->get());
+ newMaps.push_back(map);
+ continue;
+ }
+ newIns.push_back(broadcastOp.getInput());
+ // push in broadcastOp's broadcast map.
+ newMaps.push_back(broadcastOp.getMatchingIndexingMap(
+ broadcastOp.getDpsInputOperand(0)));
+ changed = true;
+ }
+ if (!changed)
+ return failure();
+ newMaps.push_back(op.getIndexingMapsArray().back());
+
+ rewriter.replaceOpWithNewOp<ElementwiseOp>(
+ op, newIns, op.getDpsInits()[0], op.getKindAttr(),
+ rewriter.getAffineMapArrayAttr(newMaps));
+ return success();
+ }
+};
+
struct LinalgFoldIntoElementwisePass
: public impl::LinalgFoldIntoElementwisePassBase<
LinalgFoldIntoElementwisePass> {
@@ -84,4 +119,5 @@ struct LinalgFoldIntoElementwisePass
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
RewritePatternSet &patterns) {
patterns.add<FoldTransposePattern>(patterns.getContext());
+ patterns.add<FoldBroadcastPattern>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
index e83c32fb6a2cf..a967629e2cf29 100644
--- a/mlir/test/Dialect/Linalg/elementwise/fold.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir
@@ -41,3 +41,47 @@ func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tens
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//
+// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
+//
+func.func @unary_broadcasted(%A : tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %empty = tensor.empty() : tensor<8x16x32xf32>
+ %broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %result : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
+//
+func.func @binary_broadcasted(%A : tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
+
+ %empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
+ %broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
+ %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
+ ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
Review thread on `if (!map.isIdentity() || !broadcastOp) {`:
Hmm - also, in the case where the existing map is not the identity map, shouldn't it still compose with the broadcast map? We could then obtain the new map via composition.
Consider the case of elementwise(transpose(broadcast)), where we first fold the transpose into the elementwise; this pattern then won't fold in the broadcast.
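For illustration only (an editor's sketch of that suggestion, not code from this PR): the snippet below shows roughly what the composition could look like inside the loop of FoldBroadcastPattern::matchAndRewrite above. It reuses the names from that loop (`op`, `operand`, `newIns`, `newMaps`, `changed`) and the existing `AffineMap::compose` helper, where `a.compose(b)` evaluates `a(b(x))`.

```cpp
// Sketch (not the PR's code): instead of bailing out on non-identity maps,
// compose the operand's existing indexing map with the broadcast's input map.
AffineMap operandMap = op.getMatchingIndexingMap(operand);
auto broadcastOp = operand->get().getDefiningOp<BroadcastOp>();
if (!broadcastOp) {
  newIns.push_back(operand->get());
  newMaps.push_back(operandMap);
  continue;
}
// Build the map from the elementwise iteration space into the broadcast
// *input*: first apply the elementwise map (landing in the broadcast result
// space), then the broadcast's input map (which drops the broadcast dims).
AffineMap bcastMap =
    broadcastOp.getMatchingIndexingMap(broadcastOp.getDpsInputOperand(0));
newIns.push_back(broadcastOp.getInput());
newMaps.push_back(bcastMap.compose(operandMap));
changed = true;
```

When `operandMap` is the identity this reduces to pushing the broadcast's own map, matching the current patch; if the elementwise verifier requires projected-permutation input maps, a `newMaps.back().isProjectedPermutation()` guard could gate the rewrite.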
Yeah, this only handles the identity map case.
I was modeling this new pattern directly after the existing FoldTransposePattern, which also seems to have this same limitation (it doesn't compose with non-identity maps).
I agree we should compose the maps as you suggested. Should I add that map composition logic here in this PR, or would you prefer a new, separate PR to add this functionality for both the new FoldBroadcastPattern and the existing FoldTransposePattern?
This becomes an arbitrarily complex producer chain to analyse. We'd need a reasonable algebra here to simplify pairs and propagate transitive and associative properties, reversibility, etc.
I see two options:
- Simple: Two patterns (transpose/broadcast) that apply only to the immediate producer and elide it. Recurrently running this pass would take care of arbitrarily complex chains, but each pattern would have to compose the new map with the existing (non-identity) map.
- Complex: Multiple patterns that simplify chains of layout manipulation before merging with the op (elementwise/contract/generic), and then a simple apply to the argument directly (replace map with an affine.apply map).
To me, the first option is much simpler, and iterative. It might not be as complete as the second, but we'll rarely see complex cases that would need a catch-all mega pass.
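A minimal sketch of how option 1 plays out on the driver side (an assumption about the pass body, not code from this PR; the greedy driver call is MLIR's standard entry point):

```cpp
// Sketch of option 1, assuming LinalgFoldIntoElementwisePass uses the greedy
// rewrite driver. Each pattern only inspects the immediate producer, but the
// driver re-applies patterns until a fixed point, so a chain like
// elementwise(transpose(broadcast(x))) folds in two steps -- transpose first,
// then broadcast -- provided each pattern composes the producer's map with
// the operand's existing (possibly non-identity) map.
void runOnOperation() override {
  RewritePatternSet patterns(&getContext());
  populateLinalgFoldIntoElementwisePatterns(patterns);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
```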
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with @rengolin "This becomes an arbitrarily complex producer chain to analyse. We'd need a reasonable algebra here to simplify pairs and propagate transitive and associative properties, reversibility, etc."
transpose + broadcast + the existing map should be seen as a projected permutation. Let me know if you want to go down that path -- I did something similar here -- https://github.com/llvm/llvm-project/blame/main/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp.
@MaheshRavishankar had pointed me in this direction originally.
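To make the projected-permutation framing concrete, here is a hypothetical helper (the name and placement are invented for illustration; only the linalg/DPS accessors it calls are existing APIs) that exposes any transpose or broadcast producer as a value plus a map into its input, which the elementwise patterns could then compose as sketched above:

```cpp
// Hypothetical helper, not part of this PR: view a linalg.transpose or
// linalg.broadcast producer uniformly as (input value, map from the
// producer's result space into that input). Both ops index their result with
// the identity map, so getMatchingIndexingMap on the input operand gives
// exactly that map, and it is a projected permutation.
static std::optional<std::pair<Value, AffineMap>>
matchProjectedPermutationProducer(Value v) {
  if (auto transposeOp = v.getDefiningOp<TransposeOp>()) {
    Value input = transposeOp.getInput();
    AffineMap map = transposeOp.getMatchingIndexingMap(
        transposeOp.getDpsInputOperand(0));
    return std::make_pair(input, map);
  }
  if (auto broadcastOp = v.getDefiningOp<BroadcastOp>()) {
    Value input = broadcastOp.getInput();
    AffineMap map = broadcastOp.getMatchingIndexingMap(
        broadcastOp.getDpsInputOperand(0));
    return std::make_pair(input, map);
  }
  return std::nullopt;
}
```

A single pattern could then walk producers through this helper, composing the returned map with the operand's map at each step, and stop once the composed map is no longer a projected permutation.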
Thank you all for the feedback. I have updated the PR based on your suggestions. The pattern now handles the case where the elementwise op has an existing non-identity map (like a transpose), composing it with the broadcast's map. I've also added new test cases that cover this composition logic. Could you please take another look when you have a chance? I'd appreciate knowing whether this implementation is on the right track.
Fold broadcast into new elementwise Op which has affine-map attached.