@@ -747,20 +747,6 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
747747
748748// -----
749749
750- // CHECK-LABEL: fold_extract_broadcast_diff_input_output_vec
751- // CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
752- // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
753- // CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
754- // CHECK: return %[[B]] : vector<4xf32>
755- func.func @fold_extract_broadcast_diff_input_output_vec (%a : vector <2 x4 xf32 >,
756- %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
757- %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x2 x4 xf32 >
758- %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
759- return %r : vector <4 xf32 >
760- }
761-
762- // -----
763-
764750// CHECK-LABEL: fold_extract_broadcast_negative
765751// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
766752// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -797,9 +783,25 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
797783
798784// -----
799785
786+ // CHECK-LABEL: fold_extract_broadcast_to_lower_rank
787+ // CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
788+ // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
789+ // CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
790+ // CHECK: return %[[B]] : vector<4xf32>
791+ // rank(extract_output) < rank(broadcast_input)
792+ func.func @fold_extract_broadcast_to_lower_rank (%a : vector <2 x4 xf32 >,
793+ %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
794+ %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x2 x4 xf32 >
795+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
796+ return %r : vector <4 xf32 >
797+ }
798+
799+ // -----
800+
800801// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
801802// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
802803// CHECK: return %[[B]] : vector<4xf32>
804+ // rank(extract_output) > rank(broadcast_input)
803805func.func @fold_extract_broadcast_to_higher_rank (%a : f32 , %idx0 : index , %idx1 : index )
804806 -> vector <4 xf32 > {
805807 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
@@ -813,6 +815,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
813815// CHECK-SAME: %[[A:.*]]: vector<1xf32>
814816// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
815817// CHECK: return %[[R]] : vector<8xf32>
818+ // rank(extract_output) == rank(broadcast_input)
816819func.func @fold_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >, %idx0 : index )
817820 -> vector <8 xf32 > {
818821 %b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
0 commit comments