Skip to content

Commit 459bcb4

Browse files
committed
add checks for nontrivial map cases
1 parent b1d15b2 commit 459bcb4

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mlir/test/Dialect/Linalg/fusion-elementwise.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,14 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
130130
return %sqrt : tensor<8x10xf32>
131131
}
132132

133+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
134+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
133135
// CHECK-LABEL: func @elementwise_ops
134136
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
135137
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
136138
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
137139
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
140+
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
138141
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
139142
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
140143
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
@@ -193,11 +196,14 @@ func.func @map_genric_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tens
193196
return %sqrt : tensor<8x10xf32>
194197
}
195198

199+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
200+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
196201
// CHECK-LABEL: func @map_genric_ops
197202
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
198203
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
199204
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
200205
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
206+
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
201207
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
202208
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
203209
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]

mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,14 @@ struct TestLinalgElementwiseFusion
151151
MLIRContext *context = &this->getContext();
152152
func::FuncOp funcOp = this->getOperation();
153153

154+
auto controlFn = [](OpOperand *operand) {
155+
auto owner = cast<linalg::LinalgOp>(operand->getOwner());
156+
auto producer = cast<linalg::LinalgOp>(operand->get().getDefiningOp());
157+
return (linalg::isElementwise(owner) && linalg::isElementwise(producer)) && (!isa<linalg::BroadcastOp>(producer) && !isa<linalg::BroadcastOp>(owner));
158+
};
154159
if (fuseGenericOps) {
155160
RewritePatternSet fusionPatterns(context);
156-
auto controlFn = [](OpOperand *operand) { return true; };
161+
// auto controlFn = [](OpOperand *operand) { return true; };
157162
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
158163
if (failed(applyPatternsGreedily(funcOp.getBody(),
159164
std::move(fusionPatterns))))

0 commit comments

Comments
 (0)