@@ -438,6 +438,17 @@ func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>
438438 return %1 : f32
439439}
440440
441+ // CHECK-LABEL: @extract_elementwise_arg_res_different_types
442+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xindex>)
443+ func.func @extract_elementwise_arg_res_different_types (%arg0: vector <4 xindex >) -> i64 {
444+ // CHECK: %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
445+ // CHECK: %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
446+ // CHECK: return %[[RES]] : i64
447+ %0 = arith.index_cast %arg0: vector <4 xindex > to vector <4 xi64 >
448+ %1 = vector.extract %0 [1 ] : i64 from vector <4 xi64 >
449+ return %1 : i64
450+ }
451+
441452// CHECK-LABEL: @extract_elementwise_vec
442453// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
443454func.func @extract_elementwise_vec (%arg0: vector <2 x4 xf32 >, %arg1: vector <2 x4 xf32 >) -> vector <4 xf32 > {
@@ -461,3 +472,27 @@ func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector
461472 %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
462473 return %1 , %0 : f32 , vector <4 xf32 >
463474}
475+
476+ // CHECK-LABEL: @extract_elementwise_not_one_res
477+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
478+ func.func @extract_elementwise_not_one_res (%arg0: vector <4 xi32 >, %arg1: vector <4 xi32 >) -> i32 {
479+ // Do not propagate extract, as elementwise has more than 1 result.
480+ // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
481+ // CHECK: %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
482+ // CHECK: return %[[EXT]] : i32
483+ %low , %hi = arith.mulsi_extended %arg0 , %arg1 : vector <4 xi32 >
484+ %1 = vector.extract %low [1 ] : i32 from vector <4 xi32 >
485+ return %1 : i32
486+ }
487+
488+ // CHECK-LABEL: @extract_not_elementwise
489+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xi64>)
490+ func.func @extract_not_elementwise (%arg0: vector <4 xi64 >) -> i64 {
491+ // `test.increment` is not an elemewise op.
492+ // CHECK: %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
493+ // CHECK: %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>
494+ // CHECK: return %[[RES]] : i64
495+ %0 = test.increment %arg0: vector <4 xi64 >
496+ %1 = vector.extract %0 [1 ] : i64 from vector <4 xi64 >
497+ return %1 : i64
498+ }
0 commit comments