@@ -115,21 +115,25 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
115115
116116// -----
117117
118- func.func @elementwise_ops (%in1: tensor <8 xf32 >, %in2: tensor <8 xf32 >) -> tensor <8 xf32 > {
119- %fill = tensor.empty () : tensor <8 xf32 >
118+ #identity = affine_map <(d0 , d1 ) -> (d0 , d1 )>
119+ #bcast = affine_map <(d0 , d1 ) -> (d0 )>
120+ func.func @elementwise_ops (%in1: tensor <8 xf32 >, %in2: tensor <8 x10 xf32 >) -> tensor <8 x10 xf32 > {
121+ %fill = tensor.empty () : tensor <8 x10 xf32 >
120122 %add = linalg.elementwise
121123 kind =#linalg.elementwise_kind <add >
122- ins (%in1 , %in2: tensor <8 xf32 >, tensor <8 xf32 >) outs (%fill: tensor <8 xf32 >) -> tensor <8 xf32 >
123- %wqrt = linalg.elementwise
124+ indexing_maps = [#bcast , #identity , #identity ]
125+ ins (%in1 , %in2: tensor <8 xf32 >, tensor <8 x10 xf32 >) outs (%fill: tensor <8 x10 xf32 >) -> tensor <8 x10 xf32 >
126+ %sqrt = linalg.elementwise
124127 kind =#linalg.elementwise_kind <sqrt >
125- ins (%add : tensor <8 xf32 >) outs (%fill : tensor <8 xf32 >) -> tensor <8 xf32 >
126- return %wqrt : tensor <8 xf32 >
128+ indexing_maps = [#identity , #identity ]
129+ ins (%add : tensor <8 x10 xf32 >) outs (%fill : tensor <8 x10 xf32 >) -> tensor <8 x10 xf32 >
130+ return %sqrt : tensor <8 x10 xf32 >
127131}
128132
129133// CHECK-LABEL: func @elementwise_ops
130134// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
131- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32 >
132- // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32 >
135+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32 >
136+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32 >
133137// CHECK: %[[FUSED_OP:.+]] = linalg.generic
134138// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
135139// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
0 commit comments