@@ -710,24 +710,38 @@ func.func @fold_extract_transpose(
710710
711711// -----
712712
713- // CHECK-LABEL: fold_extract_broadcast
713+ // CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
714714// CHECK-SAME: %[[A:.*]]: f32
715715// CHECK: return %[[A]] : f32
716- func.func @fold_extract_broadcast (%a : f32 ) -> f32 {
716+ func.func @fold_extract_broadcast_same_input_output_scalar (%a : f32 ,
717+ %idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
717718 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
718- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
719+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
719720 return %r : f32
720721}
721722
722723// -----
723724
724- // CHECK-LABEL: fold_extract_broadcast_0dvec
725+ // CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
726+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>
727+ // CHECK: return %[[A]] : vector<4xf32>
728+ func.func @fold_extract_broadcast_same_input_output_vec (%a : vector <4 xf32 >,
729+ %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
730+ %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
731+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
732+ return %r : vector <4 xf32 >
733+ }
734+
735+ // -----
736+
737+ // CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
725738// CHECK-SAME: %[[A:.*]]: vector<f32>
726739// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
727740// CHECK: return %[[B]] : f32
728- func.func @fold_extract_broadcast_0dvec (%a : vector <f32 >) -> f32 {
741+ func.func @fold_extract_broadcast_0dvec_input_scalar_output (%a : vector <f32 >,
742+ %idx0 : index , %idx1 : index , %idx2: index ) -> f32 {
729743 %b = vector.broadcast %a : vector <f32 > to vector <1 x2 x4 xf32 >
730- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
744+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
731745 return %r : f32
732746}
733747
@@ -747,57 +761,68 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
747761// CHECK-LABEL: fold_extract_splat
748762// CHECK-SAME: %[[A:.*]]: f32
749763// CHECK: return %[[A]] : f32
750- func.func @fold_extract_splat (%a : f32 ) -> f32 {
764+ func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
751765 %b = vector.splat %a : vector <1 x2 x4 xf32 >
752- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
766+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
753767 return %r : f32
754768}
755769
756770// -----
757771
758- // CHECK-LABEL: fold_extract_broadcast_vector
759- // CHECK-SAME: %[[A:.*]]: vector<4xf32>
760- // CHECK: return %[[A]] : vector<4xf32>
761- func.func @fold_extract_broadcast_vector (%a : vector <4 xf32 >) -> vector <4 xf32 > {
762- %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
763- %r = vector.extract %b [0 , 1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
764- return %r : vector <4 xf32 >
772+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
773+ // CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
774+ // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
775+ // CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
776+ // CHECK: return %[[R]] : f32
777+ func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <2 x1 xf32 >,
778+ %idx : index , %idx1 : index , %idx2 : index ) -> f32 {
779+ %b = vector.broadcast %a : vector <2 x1 xf32 > to vector <1 x2 x4 xf32 >
780+ %r = vector.extract %b [%idx , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
781+ return %r : f32
765782}
766783
767784// -----
768785
769- // CHECK-LABEL: fold_extract_broadcast
770- // CHECK-SAME: %[[A:.*]]: vector<4xf32>
771- // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
772- // CHECK: return %[[R]] : f32
773- func.func @fold_extract_broadcast (%a : vector <4 xf32 >) -> f32 {
774- %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
775- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
776- return %r : f32
786+ // CHECK-LABEL: fold_extract_broadcast_to_lower_rank
787+ // CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
788+ // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
789+ // CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
790+ // CHECK: return %[[B]] : vector<4xf32>
791+ // rank(extract_output) < rank(broadcast_input)
792+ func.func @fold_extract_broadcast_to_lower_rank (%a : vector <2 x4 xf32 >,
793+ %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
794+ %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x2 x4 xf32 >
795+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
796+ return %r : vector <4 xf32 >
777797}
778798
779799// -----
780800
781- // CHECK-LABEL: fold_extract_broadcast
801+ // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
782802// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
783803// CHECK: return %[[B]] : vector<4xf32>
784- func.func @fold_extract_broadcast (%a : f32 ) -> vector <4 xf32 > {
804+ // rank(extract_output) > rank(broadcast_input)
805+ func.func @fold_extract_broadcast_to_higher_rank (%a : f32 , %idx0 : index , %idx1 : index )
806+ -> vector <4 xf32 > {
785807 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
786- %r = vector.extract %b [0 , 1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
808+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
787809 return %r : vector <4 xf32 >
788810}
789811
790812// -----
791813
792- // CHECK-LABEL: fold_extract_broadcast
814+ // CHECK-LABEL: fold_extract_broadcast_to_equal_rank
793815// CHECK-SAME: %[[A:.*]]: vector<1xf32>
794816// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
795817// CHECK: return %[[R]] : vector<8xf32>
796- func.func @fold_extract_broadcast (%a : vector <1 xf32 >) -> vector <8 xf32 > {
818+ // rank(extract_output) == rank(broadcast_input)
819+ func.func @fold_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >, %idx0 : index )
820+ -> vector <8 xf32 > {
797821 %b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
798- %r = vector.extract %b [0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
822+ %r = vector.extract %b [%idx0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
799823 return %r : vector <8 xf32 >
800824}
825+
801826// -----
802827
803828// CHECK-LABEL: @fold_extract_shuffle
0 commit comments