Skip to content

Commit cf67ab6

Browse files
committed
switch elementwise test to broadcast version
1 parent 58582bf commit cf67ab6

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

mlir/test/Dialect/Linalg/fusion-elementwise.mlir

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,21 +115,25 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
115115

116116
// -----
117117

118-
func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
119-
%fill = tensor.empty() : tensor<8xf32>
118+
#identity = affine_map<(d0, d1) -> (d0, d1)>
119+
#bcast = affine_map<(d0, d1) -> (d0)>
120+
func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tensor<8x10xf32> {
121+
%fill = tensor.empty() : tensor<8x10xf32>
120122
%add = linalg.elementwise
121123
kind=#linalg.elementwise_kind<add>
122-
ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) -> tensor<8xf32>
123-
%wqrt = linalg.elementwise
124+
indexing_maps = [#bcast, #identity, #identity]
125+
ins(%in1, %in2: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) -> tensor<8x10xf32>
126+
%sqrt = linalg.elementwise
124127
kind=#linalg.elementwise_kind<sqrt>
125-
ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>) -> tensor<8xf32>
126-
return %wqrt : tensor<8xf32>
128+
indexing_maps = [#identity, #identity]
129+
ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) -> tensor<8x10xf32>
130+
return %sqrt : tensor<8x10xf32>
127131
}
128132

129133
// CHECK-LABEL: func @elementwise_ops
130134
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
131-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
132-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
135+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
136+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
133137
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
134138
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
135139
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):

0 commit comments

Comments
 (0)