@@ -141,3 +141,66 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
141141// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
142142// CHECK-NEXT: linalg.yield %[[SQRT]]
143143// CHECK-NOT: linalg.map
144+
145+ // -----
146+
147+ func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {  // two chained linalg.map ops: the second takes the first's result as an input
148+   %init = tensor.empty() : tensor<8xf32>  // one destination tensor, passed as outs to both maps
149+   %exp_of_sum = linalg.map ins(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>)
150+     (%lhs: f32, %rhs: f32) {  // producer body: exp(lhs + rhs)
151+       %sum = arith.addf %lhs, %rhs : f32
152+       %e = math.exp %sum : f32
153+       linalg.yield %e : f32
154+     }
155+   %result = linalg.map ins(%exp_of_sum, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>)
156+     (%val: f32, %scale: f32) {  // consumer body: sqrt(val) * scale
157+       %root = math.sqrt %val : f32
158+       %prod = arith.mulf %root, %scale : f32
159+       linalg.yield %prod : f32
160+     }
161+   return %result : tensor<8xf32>
162+ }
163+
164+ // CHECK-LABEL: func @map_multi_ops
165+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
166+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
167+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<8xf32>
168+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
169+ // CHECK: %[[FUSED_OP:.+]] = linalg.generic
170+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) outs(%[[EMPTY]] :
171+ // CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
172+ // CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
173+ // CHECK-NEXT: %[[EXP:.*]] = math.exp %[[ADD]]
174+ // CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[EXP]]
175+ // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[SQRT]], %[[IN2]]
176+ // CHECK-NEXT: linalg.yield %[[MUL]]
177+ // CHECK-NOT: linalg.map
178+
179+ // -----
180+
181+ #identity = affine_map<(d0, d1) -> (d0, d1)>  // indexes an operand with both loop dimensions
182+ #bcast = affine_map<(d0, d1) -> (d0)>  // indexes an operand with d0 only, so it is reused across d1
183+ func.func @map_genric_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tensor<8x10xf32> {  // a linalg.generic producer whose result is the input of a linalg.map consumer
184+   %fill = tensor.empty() : tensor<8x10xf32>  // one destination tensor, passed as outs to both ops
185+   %add = linalg.generic
186+     {indexing_maps = [#bcast, #identity, #identity], iterator_types = ["parallel", "parallel"]}
187+     ins(%arg0, %arg1 : tensor<8xf32>, tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) {
188+     ^bb0(%in0: f32, %in1: f32, %out: f32):  // %out is unused; the body yields in0 + in1
189+       %sum = arith.addf %in0, %in1 : f32
190+       linalg.yield %sum : f32
191+   } -> tensor<8x10xf32>
192+   %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>)  // short-form map applying math.sqrt to each element
193+   return %sqrt : tensor<8x10xf32>
194+ }
195+
196+ // CHECK-LABEL: func @map_genric_ops
197+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
198+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
199+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
200+ // CHECK: %[[FUSED_OP:.+]] = linalg.generic
201+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
202+ // CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
203+ // CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
204+ // CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
205+ // CHECK-NEXT: linalg.yield %[[SQRT]]
206+ // CHECK-NOT: linalg.map
0 commit comments