@@ -710,10 +710,10 @@ func.func @fold_extract_transpose(
710710
711711// -----
712712
713- // CHECK-LABEL: fold_extract_broadcast_same_input_output
713+ // CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
714714// CHECK-SAME: %[[A:.*]]: f32
715715// CHECK: return %[[A]] : f32
716- func.func @fold_extract_broadcast_same_input_output (%a : f32 ,
716+ func.func @fold_extract_broadcast_same_input_output_vec (%a : f32 ,
717717 %idx0 : index ,
718718 %idx1 : index ) -> f32 {
719719 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
@@ -752,6 +752,22 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
752752
753753// -----
754754
755+ // CHECK-LABEL: fold_extract_broadcast_diff_input_output_vec
756+ // CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
757+ // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
758+ // CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
759+ // CHECK: return %[[B]] : vector<4xf32>
760+ func.func @fold_extract_broadcast_diff_input_output_vec (%a : vector <2 x4 xf32 >,
761+ %idx0 : index ,
762+ %idx1 : index ) -> vector <4 xf32 > {
763+ %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x2 x4 xf32 >
764+ // The indices don't matter for this folder, so we use mixed indices.
765+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
766+ return %r : vector <4 xf32 >
767+ }
768+
769+ // -----
770+
755771// CHECK-LABEL: fold_extract_broadcast_negative
756772// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
757773// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -776,13 +792,17 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
776792// -----
777793
778794// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
779- // CHECK-SAME: %[[A:.*]]: vector<4xf32>
780- // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
795+ // CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
796+ // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
797+ // CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
781798// CHECK: return %[[R]] : f32
782- func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <4 xf32 >, %idx : index ) -> f32 {
783- %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
799+ func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <2 x1 xf32 >,
800+ %idx : index ,
801+ %idx1 : index ,
802+ %idx2 : index ) -> f32 {
803+ %b = vector.broadcast %a : vector <2 x1 xf32 > to vector <1 x2 x4 xf32 >
784804 // The indices don't matter for this folder, so we use mixed indices.
785- %r = vector.extract %b [%idx , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
805+ %r = vector.extract %b [%idx , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
786806 return %r : f32
787807}
788808
0 commit comments