@@ -1,6 +1,6 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{fuse-pad-with-consumers}))" --split-input-file %s | FileCheck %s
 
-util.func public @fuse_with_consumer(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+util.func public @fuse_with_consumer_named_op(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
     %arg2 : index, %arg3 : index, %arg4 : index,
     %arg5 : tensor<?x?x?x?xf32>, %arg6 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   %cst = arith.constant 42.0 : f32
@@ -12,7 +12,7 @@ util.func public @fuse_with_consumer(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
       outs(%arg6 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   util.return %1 : tensor<?x?x?x?xf32>
 }
-// CHECK-LABEL: util.func public @fuse_with_consumer
+// CHECK-LABEL: util.func public @fuse_with_consumer_named_op
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
 // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
@@ -23,3 +23,30 @@ util.func public @fuse_with_consumer(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
 // CHECK-SAME: outs(%[[ARG6]] :
 // CHECK: flow.return %[[CONV]]
 // CHECK: util.return %[[RETURN]]
+
+// -----
+
+util.func public @fuse_with_consumer_generalized(%arg0: tensor<?x?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: tensor<?x?x?x?xf32>, %arg6: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %cst = arith.constant 4.200000e+01 : f32
+  %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[%arg1, %arg2, %arg3, %arg4] {
+  ^bb0(%arg7: index, %arg8: index, %arg9: index, %arg10: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%padded, %arg5 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%arg6 : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32 } -> tensor<?x?x?x?xf32>
+  util.return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: util.func public @fuse_with_consumer_generalized
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[RETURN:.+]] = flow.dispatch.region
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]]
+// CHECK: %[[CONV:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PADDED]], %[[ARG5]] :
+// CHECK-SAME: outs(%[[ARG6]] :
+// CHECK: flow.return %[[CONV]]
+// CHECK: util.return %[[RETURN]]