Skip to content

Commit 7f9b6ae

Browse files
IanWood1keshavvinayak01
authored andcommitted
[Dispatch Creation] Fuse pad with generic conv consumer (iree-org#21606)
Use `linalg::isaConvolutionOpInterface` instead of `isa<linalg::ConvolutionOpInterface>` to detect generic convolutions. Signed-off-by: Ian Wood <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 08fa1e0 commit 7f9b6ae

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,8 +768,9 @@ static bool isFusableWithProducer(
768768
return false;
769769
}
770770

771+
auto linalgConsumer = dyn_cast<linalg::LinalgOp>(consumer);
771772
if (options.fusePadWithConsumers && isa<tensor::PadOp>(producer) &&
772-
isa<linalg::ConvolutionOpInterface>(consumer)) {
773+
linalgConsumer && linalg::isaConvolutionOpInterface(linalgConsumer)) {
773774
return true;
774775
}
775776

compiler/src/iree/compiler/DispatchCreation/test/pad_fusion_with_consumer.mlir

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{fuse-pad-with-consumers}))" --split-input-file %s | FileCheck %s
22

3-
util.func public @fuse_with_consumer(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
3+
util.func public @fuse_with_consumer_named_op(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
44
%arg2 : index, %arg3 : index, %arg4 : index,
55
%arg5 : tensor<?x?x?x?xf32>, %arg6 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
66
%cst = arith.constant 42.0 : f32
@@ -12,7 +12,7 @@ util.func public @fuse_with_consumer(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
1212
outs(%arg6 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
1313
util.return %1 : tensor<?x?x?x?xf32>
1414
}
15-
// CHECK-LABEL: util.func public @fuse_with_consumer
15+
// CHECK-LABEL: util.func public @fuse_with_consumer_named_op
1616
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
1717
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
1818
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
@@ -23,3 +23,30 @@ util.func public @fuse_with_consumer(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
2323
// CHECK-SAME: outs(%[[ARG6]] :
2424
// CHECK: flow.return %[[CONV]]
2525
// CHECK: util.return %[[RETURN]]
26+
27+
// -----
28+
29+
util.func public @fuse_with_consumer_generalized(%arg0: tensor<?x?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: tensor<?x?x?x?xf32>, %arg6: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
30+
%cst = arith.constant 4.200000e+01 : f32
31+
%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[%arg1, %arg2, %arg3, %arg4] {
32+
^bb0(%arg7: index, %arg8: index, %arg9: index, %arg10: index):
33+
tensor.yield %cst : f32
34+
} : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
35+
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%padded, %arg5 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%arg6 : tensor<?x?x?x?xf32>) {
36+
^bb0(%in: f32, %in_0: f32, %out: f32):
37+
%1 = arith.mulf %in, %in_0 : f32
38+
%2 = arith.addf %out, %1 : f32
39+
linalg.yield %2 : f32 } -> tensor<?x?x?x?xf32>
40+
util.return %0 : tensor<?x?x?x?xf32>
41+
}
42+
// CHECK-LABEL: util.func public @fuse_with_consumer_generalized
43+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
44+
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
45+
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
46+
// CHECK: %[[RETURN:.+]] = flow.dispatch.region
47+
// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]]
48+
// CHECK: %[[CONV:.+]] = linalg.generic
49+
// CHECK-SAME: ins(%[[PADDED]], %[[ARG5]] :
50+
// CHECK-SAME: outs(%[[ARG6]] :
51+
// CHECK: flow.return %[[CONV]]
52+
// CHECK: util.return %[[RETURN]]

0 commit comments

Comments
 (0)