@@ -139,3 +139,46 @@ util.func public @math_sin() {
139139// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
140140// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0,
141141// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1,
142+
143+ // -----
144+
145+ util.func public @use_in_generic(%arg0 : tensor<1x20x128x2x8xf32>) -> tensor<1x20x128x2x8xf32> {
146+   %cst = arith.constant dense_resource<__elided__> : tensor<128x2x8xf32>
147+   %cst_0 = arith.constant dense_resource<__elided__> : tensor<128x2x8xf32>
148+   %cst_1 = arith.constant 2.500000e-01 : f32
149+   %c0 = arith.constant 0 : index
150+   %c1 = arith.constant 1 : index
151+   %1 = tensor.empty() : tensor<1x20x128x2x8xf32>  // used as the `outs` init of both generics below
152+   %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x20x128x2x8xf32>) outs(%1 : tensor<1x20x128x2x8xf32>) {
153+   ^bb0(%in: f32, %out: f32):
154+     %6 = arith.mulf %in, %cst_1 : f32  // producer: scale every element of %arg0 by 0.25
155+     linalg.yield %6 : f32
156+   } -> tensor<1x20x128x2x8xf32>
157+   %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%2, %cst_0, %cst : tensor<1x20x128x2x8xf32>, tensor<128x2x8xf32>, tensor<128x2x8xf32>) outs(%1 : tensor<1x20x128x2x8xf32>) {
158+   ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
159+     %6 = linalg.index 0 : index
160+     %7 = linalg.index 1 : index
161+     %8 = linalg.index 2 : index
162+     %9 = linalg.index 3 : index
163+     %10 = linalg.index 4 : index
164+     %11 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%7, %6]  // index(1) + index(0) * 20
165+     %12 = arith.subi %c1, %9 : index  // 1 - index(3)
166+     %extracted = tensor.extract %2[%c0, %11, %8, %12, %10] : tensor<1x20x128x2x8xf32>  // reads %2 by explicit index, in addition to receiving it as %in
167+     %13 = arith.negf %extracted : f32
168+     %14 = arith.cmpi eq, %12, %c1 : index  // true exactly when index(3) == 0
169+     %15 = arith.select %14, %13, %extracted : f32  // negated value when %14 holds, otherwise the value as extracted
170+     %16 = arith.mulf %15, %in_3 : f32
171+     %17 = arith.mulf %in, %in_2 : f32
172+     %18 = arith.addf %17, %16 : f32  // %in * %in_2 + %15 * %in_3
173+     linalg.yield %18 : f32
174+   } -> tensor<1x20x128x2x8xf32>
175+   util.return %3 : tensor<1x20x128x2x8xf32>
176+ }
177+
178+ // The two linalg.generic ops cannot be fused: %2 is an `ins` operand of %3 and is also read via tensor.extract inside %3's body.
179+ //
180+ // CHECK-LABEL: util.func public @use_in_generic(
181+ // CHECK: %[[GENERIC0:.+]] = linalg.generic
182+ // CHECK: %[[GENERIC1:.+]] = linalg.generic
183+ // CHECK-SAME: ins(%[[GENERIC0]]
184+ // CHECK: util.return %[[GENERIC1]]
0 commit comments