@@ -244,6 +244,48 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
244244
245245// -----
246246
247+ // CHECK-LABEL: @extract_elementwise
248+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
249+ func.func @extract_elementwise (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >) -> f32 {
250+ // CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
251+ // CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
252+ // CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
253+ // CHECK: return %[[RES]] : f32
254+ %0 = arith.addf %arg0 , %arg1 : vector <4 xf32 >
255+ %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
256+ return %1 : f32
257+ }
258+
259+ // -----
260+
261+ // CHECK-LABEL: @extract_vec_elementwise
262+ // CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
263+ func.func @extract_vec_elementwise (%arg0: vector <2 x4 xf32 >, %arg1: vector <2 x4 xf32 >) -> vector <4 xf32 > {
264+ // CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
265+ // CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
266+ // CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
267+ // CHECK: return %[[RES]] : vector<4xf32>
268+ %0 = arith.addf %arg0 , %arg1 : vector <2 x4 xf32 >
269+ %1 = vector.extract %0 [1 ] : vector <4 xf32 > from vector <2 x4 xf32 >
270+ return %1 : vector <4 xf32 >
271+ }
272+
273+ // -----
274+
275+ // CHECK-LABEL: @extract_elementwise_use
276+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
277+ func.func @extract_elementwise_use (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >) -> (f32 , vector <4 xf32 >) {
278+ // Dop not propagate extract, as elementwise has other uses
279+ // CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
280+ // CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
281+ // CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
282+ %0 = arith.addf %arg0 , %arg1 : vector <4 xf32 >
283+ %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
284+ return %1 , %0 : f32 , vector <4 xf32 >
285+ }
286+
287+ // -----
288+
247289// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
248290func.func @constant_mask_transpose_to_transposed_constant_mask () -> (vector <2 x3 x4 xi1 >, vector <4 x2 x3 xi1 >) {
249291 // CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
0 commit comments