Skip to content

Commit fd9b2eb

Browse files
committed
[mlir][linalg][elementwise] Fold broadcast into new elementwise
1 parent 95f2728 commit fd9b2eb

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,41 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
6464
}
6565
};
6666

67+
struct FoldBroadcastPattern : public OpRewritePattern<ElementwiseOp> {
68+
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
69+
70+
LogicalResult matchAndRewrite(ElementwiseOp op,
71+
PatternRewriter &rewriter) const override {
72+
bool changed = false;
73+
SmallVector<Value> newIns;
74+
SmallVector<AffineMap> newMaps;
75+
for (OpOperand *operand : op.getDpsInputOperands()) {
76+
AffineMap map = op.getMatchingIndexingMap(operand);
77+
auto broadcastOp = operand->get().getDefiningOp<BroadcastOp>();
78+
79+
if (!map.isIdentity() || !broadcastOp) {
80+
// push in original operand and its map.
81+
newIns.push_back(operand->get());
82+
newMaps.push_back(map);
83+
continue;
84+
}
85+
newIns.push_back(broadcastOp.getInput());
86+
// push in broadcastOp's broadcast map.
87+
newMaps.push_back(broadcastOp.getMatchingIndexingMap(
88+
broadcastOp.getDpsInputOperand(0)));
89+
changed = true;
90+
}
91+
if (!changed)
92+
return failure();
93+
newMaps.push_back(op.getIndexingMapsArray().back());
94+
95+
rewriter.replaceOpWithNewOp<ElementwiseOp>(
96+
op, newIns, op.getDpsInits()[0], op.getKindAttr(),
97+
rewriter.getAffineMapArrayAttr(newMaps));
98+
return success();
99+
}
100+
};
101+
67102
struct LinalgFoldIntoElementwisePass
68103
: public impl::LinalgFoldIntoElementwisePassBase<
69104
LinalgFoldIntoElementwisePass> {
@@ -84,4 +119,5 @@ struct LinalgFoldIntoElementwisePass
84119
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
85120
RewritePatternSet &patterns) {
86121
patterns.add<FoldTransposePattern>(patterns.getContext());
122+
patterns.add<FoldBroadcastPattern>(patterns.getContext());
87123
}

mlir/test/Dialect/Linalg/elementwise/fold.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,47 @@ func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tens
4141
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
4242
return %result : tensor<?x?xf32>
4343
}
44+
45+
// -----
46+
47+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
48+
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
49+
//
50+
// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
51+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
52+
// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
53+
// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
54+
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
55+
//
56+
func.func @unary_broadcasted(%A : tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
57+
%empty = tensor.empty() : tensor<8x16x32xf32>
58+
%broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
59+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
60+
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
61+
return %result : tensor<8x16x32xf32>
62+
}
63+
64+
// -----
65+
66+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
67+
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
68+
//
69+
// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
70+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
71+
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
72+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
73+
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
74+
//
75+
func.func @binary_broadcasted(%A : tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
76+
%c0 = arith.constant 0 : index
77+
%c1 = arith.constant 1 : index
78+
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
79+
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
80+
81+
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
82+
%broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
83+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
84+
ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
85+
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
86+
return %result : tensor<?x?xf32>
87+
}

0 commit comments

Comments
 (0)