Skip to content

Commit b1d15b2

Browse files
committed
add requested tests
1 parent 7d402c1 commit b1d15b2

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

mlir/test/Dialect/Linalg/fusion-elementwise.mlir

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,66 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
141141
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
142142
// CHECK-NEXT: linalg.yield %[[SQRT]]
143143
// CHECK-NOT: linalg.map
144+
145+
// -----
146+
147+
func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
148+
%fill = tensor.empty() : tensor<8xf32>
149+
%add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
150+
(%in0 : f32, %in1 : f32) {
151+
%add = arith.addf %in0, %in1 : f32
152+
%exp = math.exp %add : f32
153+
linalg.yield %exp : f32
154+
}
155+
%sqrt_mul = linalg.map ins(%add_exp, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>)
156+
(%in0 : f32, %in1 : f32) {
157+
%sqrt = math.sqrt %in0 : f32
158+
%mul = arith.mulf %sqrt, %in1 : f32
159+
linalg.yield %mul : f32
160+
}
161+
return %sqrt_mul : tensor<8xf32>
162+
}
163+
164+
// CHECK-LABEL: func @map_multi_ops
165+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
166+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
167+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<8xf32>
168+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
169+
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
170+
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) outs(%[[EMPTY]] :
171+
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
172+
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
173+
// CHECK-NEXT: %[[EXP:.*]] = math.exp %[[ADD]]
174+
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[EXP]]
175+
// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[SQRT]], %[[IN2]]
176+
// CHECK-NEXT: linalg.yield %[[MUL]]
177+
// CHECK-NOT: linalg.map
178+
179+
// -----
180+
181+
#identity = affine_map<(d0, d1) -> (d0, d1)>
182+
#bcast = affine_map<(d0, d1) -> (d0)>
183+
func.func @map_genric_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tensor<8x10xf32> {
184+
%fill = tensor.empty() : tensor<8x10xf32>
185+
%add = linalg.generic
186+
{indexing_maps = [#bcast, #identity, #identity], iterator_types = ["parallel", "parallel"]}
187+
ins(%arg0, %arg1: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) {
188+
^bb0(%in0: f32, %in1: f32, %out: f32):
189+
%add = arith.addf %in0, %in1 : f32
190+
linalg.yield %add : f32
191+
} -> tensor<8x10xf32>
192+
%sqrt = linalg.map { math.sqrt } ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>)
193+
return %sqrt : tensor<8x10xf32>
194+
}
195+
196+
// CHECK-LABEL: func @map_genric_ops
197+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
198+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
199+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
200+
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
201+
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
202+
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
203+
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
204+
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
205+
// CHECK-NEXT: linalg.yield %[[SQRT]]
206+
// CHECK-NOT: linalg.map

0 commit comments

Comments
 (0)