Skip to content

Commit 788adaa

Browse files
committed
[mlir][linalg][elementwise] Fold broadcast into new elementwise
1 parent 95f2728 commit 788adaa

File tree

2 files changed

+198
-11
lines changed

2 files changed

+198
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,54 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
4141
AffineMap map = op.getMatchingIndexingMap(operand);
4242
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
4343

44-
if (!map.isIdentity() || !transposeOp) {
44+
if (!transposeOp) {
4545
// push in original operand and its map.
4646
newIns.push_back(operand->get());
4747
newMaps.push_back(map);
4848
continue;
4949
}
5050
newIns.push_back(transposeOp.getInput());
51-
// push in transposeOp's inverse permutation map.
52-
newMaps.push_back(transposeOp.getMatchingIndexingMap(
53-
transposeOp.getDpsInputOperand(0)));
51+
// push in composed affine map.
52+
newMaps.push_back(
53+
transposeOp.getMatchingIndexingMap(transposeOp.getDpsInputOperand(0))
54+
.compose(map));
55+
changed = true;
56+
}
57+
if (!changed)
58+
return failure();
59+
newMaps.push_back(op.getIndexingMapsArray().back());
60+
61+
rewriter.replaceOpWithNewOp<ElementwiseOp>(
62+
op, newIns, op.getDpsInits()[0], op.getKindAttr(),
63+
rewriter.getAffineMapArrayAttr(newMaps));
64+
return success();
65+
}
66+
};
67+
68+
struct FoldBroadcastPattern : public OpRewritePattern<ElementwiseOp> {
69+
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
70+
71+
LogicalResult matchAndRewrite(ElementwiseOp op,
72+
PatternRewriter &rewriter) const override {
73+
bool changed = false;
74+
SmallVector<Value> newIns;
75+
SmallVector<AffineMap> newMaps;
76+
for (OpOperand *operand : op.getDpsInputOperands()) {
77+
AffineMap map = op.getMatchingIndexingMap(operand);
78+
auto broadcastOp = operand->get().getDefiningOp<BroadcastOp>();
79+
80+
if (!broadcastOp) {
81+
// push in original operand and its map.
82+
newIns.push_back(operand->get());
83+
newMaps.push_back(map);
84+
continue;
85+
}
86+
87+
newIns.push_back(broadcastOp.getInput());
88+
// push in composed affine map.
89+
newMaps.push_back(
90+
broadcastOp.getMatchingIndexingMap(broadcastOp.getDpsInputOperand(0))
91+
.compose(map));
5492
changed = true;
5593
}
5694
if (!changed)
@@ -84,4 +122,5 @@ struct LinalgFoldIntoElementwisePass
84122
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
85123
RewritePatternSet &patterns) {
86124
patterns.add<FoldTransposePattern>(patterns.getContext());
125+
patterns.add<FoldBroadcastPattern>(patterns.getContext());
87126
}

mlir/test/Dialect/Linalg/elementwise/fold.mlir

Lines changed: 155 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
1010
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
1111
//
12-
func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
12+
func.func @unary_transpose(%A: tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
1313
%empty = tensor.empty() : tensor<8x16x32xf32>
14-
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
14+
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
1515
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
16-
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
16+
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
1717
return %result : tensor<8x16x32xf32>
1818
}
1919

@@ -28,16 +28,164 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->
2828
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
2929
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
3030
//
31-
func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
31+
func.func @binary_transposed(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
3232
%c0 = arith.constant 0 : index
3333
%c1 = arith.constant 1 : index
3434
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
3535
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
3636

3737
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
38-
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
38+
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
3939
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
40-
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
41-
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
40+
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
41+
outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
4242
return %result : tensor<?x?xf32>
4343
}
44+
45+
// -----
46+
47+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
48+
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
49+
//
50+
// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
51+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
52+
// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
53+
// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
54+
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
55+
//
56+
func.func @unary_broadcasted(%A: tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
57+
%empty = tensor.empty() : tensor<8x16x32xf32>
58+
%broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
59+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
60+
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
61+
return %result : tensor<8x16x32xf32>
62+
}
63+
64+
// -----
65+
66+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
67+
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
68+
//
69+
// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
70+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
71+
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
72+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
73+
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
74+
//
75+
func.func @binary_broadcasted(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
76+
%c0 = arith.constant 0 : index
77+
%c1 = arith.constant 1 : index
78+
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
79+
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
80+
81+
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
82+
%broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
83+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
84+
ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
85+
outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
86+
return %result : tensor<?x?xf32>
87+
}
88+
89+
// -----
90+
91+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
92+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
93+
//
94+
// CHECK: func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> {
95+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
96+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
97+
// CHECK-SAME: ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32>
98+
// CHECK-NEXT: return %[[RES]] : tensor<16x32xf32>
99+
//
100+
func.func @fold_broadcast_after_transpose_fold(%A: tensor<16xf32>, %B: tensor<16x32xf32>) -> tensor<16x32xf32> {
101+
%empty_b = tensor.empty() : tensor<32x16xf32>
102+
%broadcasted_A = linalg.broadcast ins(%A : tensor<16xf32>) outs(%empty_b : tensor<32x16xf32>) dimensions = [0]
103+
104+
%empty_t = tensor.empty() : tensor<16x32xf32>
105+
%transposed_A = linalg.transpose ins(%broadcasted_A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]
106+
107+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
108+
ins(%transposed_A : tensor<16x32xf32>) outs(%B : tensor<16x32xf32>) -> tensor<16x32xf32>
109+
return %result : tensor<16x32xf32>
110+
}
111+
112+
// -----
113+
114+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
115+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
116+
//
117+
// CHECK: func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
118+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
119+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
120+
// CHECK-SAME: ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
121+
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
122+
//
123+
func.func @fold_transpose_after_broadcast_fold(%A: tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
124+
%empty_t = tensor.empty() : tensor<16x32xf32>
125+
%transposed_A = linalg.transpose ins(%A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]
126+
127+
%empty_b = tensor.empty() : tensor<8x16x32xf32>
128+
%broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<16x32xf32>) outs(%empty_b : tensor<8x16x32xf32>) dimensions = [0]
129+
130+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
131+
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
132+
return %result : tensor<8x16x32xf32>
133+
}
134+
135+
// -----
136+
137+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
138+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
139+
//
140+
// CHECK: func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
141+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
142+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
143+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
144+
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
145+
//
146+
func.func @fold_broadcast_after_transpose_fold_binary(%A: tensor<?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
147+
%c0 = arith.constant 0 : index
148+
%c1 = arith.constant 1 : index
149+
%dim0 = tensor.dim %B, %c0 : tensor<?x?xf32>
150+
%dim1 = tensor.dim %B, %c1 : tensor<?x?xf32>
151+
152+
%empty_b = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
153+
%broadcasted_A = linalg.broadcast ins(%A : tensor<?xf32>) outs(%empty_b : tensor<?x?xf32>) dimensions = [0]
154+
155+
%empty_t = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
156+
%transposed_A = linalg.transpose ins(%broadcasted_A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]
157+
158+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
159+
ins(%transposed_A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
160+
return %result : tensor<?x?xf32>
161+
}
162+
163+
// -----
164+
165+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
166+
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
167+
//
168+
// CHECK: func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?x?xf32>, %[[C:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
169+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
170+
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
171+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%[[C]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
172+
// CHECK-NEXT: return %[[RES]] : tensor<?x?x?xf32>
173+
//
174+
func.func @fold_transpose_after_broadcast_fold_binary(%A: tensor<?x?xf32>, %B: tensor<?x?x?xf32>, %C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
175+
%c0 = arith.constant 0 : index
176+
%c1 = arith.constant 1 : index
177+
%c2 = arith.constant 2 : index
178+
%dim0 = tensor.dim %B, %c0 : tensor<?x?x?xf32>
179+
%dim1 = tensor.dim %B, %c1 : tensor<?x?x?xf32>
180+
%dim2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>
181+
182+
%empty_t = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
183+
%transposed_A = linalg.transpose ins(%A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]
184+
185+
%empty_b = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
186+
%broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<?x?xf32>) outs(%empty_b : tensor<?x?x?xf32>) dimensions = [0]
187+
188+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
189+
ins(%broadcasted_A, %B : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
190+
return %result : tensor<?x?x?xf32>
191+
}

0 commit comments

Comments
 (0)