@@ -710,24 +710,44 @@ func.func @fold_extract_transpose(
710710
711711// -----
712712
713- // CHECK-LABEL: fold_extract_broadcast
713+ // CHECK-LABEL: fold_extract_broadcast_same_type
714714// CHECK-SAME: %[[A:.*]]: f32
715715// CHECK: return %[[A]] : f32
716- func.func @fold_extract_broadcast (%a : f32 ) -> f32 {
716+ func.func @fold_extract_broadcast_same_type (%a : f32 ,
717+ %idx0 : index ,
718+ %idx1 : index ) -> f32 {
717719 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
718- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
720+ // The indices don't batter for this folder, so we use mixed indices.
721+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
719722 return %r : f32
720723}
721724
722725// -----
723726
724- // CHECK-LABEL: fold_extract_broadcast_0dvec
727+ // CHECK-LABEL: fold_extract_broadcast_same_type_vec
728+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>
729+ // CHECK: return %[[A]] : vector<4xf32>
730+ func.func @fold_extract_broadcast_same_type_vec (%a : vector <4 xf32 >,
731+ %idx0 : index )
732+ -> vector <4 xf32 > {
733+ %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
734+ // The indices don't batter for this folder, so we use mixed indices.
735+ %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
736+ return %r : vector <4 xf32 >
737+ }
738+
739+ // -----
740+
741+ // CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
725742// CHECK-SAME: %[[A:.*]]: vector<f32>
726743// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
727744// CHECK: return %[[B]] : f32
728- func.func @fold_extract_broadcast_0dvec (%a : vector <f32 >) -> f32 {
745+ func.func @fold_extract_broadcast_0dvec_and_scalar (%a : vector <f32 >,
746+ %idx0 : index ,
747+ %idx1 : index ) -> f32 {
729748 %b = vector.broadcast %a : vector <f32 > to vector <1 x2 x4 xf32 >
730- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
749+ // The indices don't batter for this folder, so we use mixed indices.
750+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
731751 return %r : f32
732752}
733753
@@ -747,57 +767,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
747767// CHECK-LABEL: fold_extract_splat
748768// CHECK-SAME: %[[A:.*]]: f32
749769// CHECK: return %[[A]] : f32
750- func.func @fold_extract_splat (%a : f32 ) -> f32 {
770+ func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index ) -> f32 {
751771 %b = vector.splat %a : vector <1 x2 x4 xf32 >
752- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
772+ // The indices don't batter for this folder, so we use mixed indices.
773+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
753774 return %r : f32
754775}
755776
756777// -----
757778
758- // CHECK-LABEL: fold_extract_broadcast_vector
779+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
759780// CHECK-SAME: %[[A:.*]]: vector<4xf32>
760- // CHECK: return %[[A]] : vector<4xf32>
761- func.func @fold_extract_broadcast_vector (%a : vector <4 xf32 >) -> vector <4 xf32 > {
781+ // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
782+ // CHECK: return %[[R]] : f32
783+ func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <4 xf32 >) -> f32 {
762784 %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
763- %r = vector.extract %b [0 , 1 ] : vector < 4 x f32 > from vector <1 x2 x4 xf32 >
764- return %r : vector < 4 x f32 >
785+ %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
786+ return %r : f32
765787}
766788
767789// -----
768790
769- // CHECK-LABEL: fold_extract_broadcast
791+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
770792// CHECK-SAME: %[[A:.*]]: vector<4xf32>
771- // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
793+ // CHECK-SAME: %[[IDX:.*]]: index
794+ // CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
795+ // CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
772796// CHECK: return %[[R]] : f32
773- func.func @fold_extract_broadcast (%a : vector <4 xf32 >) -> f32 {
797+ // This folder is not yet implemented. Check that this does not fold.
798+ func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi (
799+ %a : vector <4 xf32 >,
800+ %idx : index ) -> f32 {
774801 %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
775- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
802+ %r = vector.extract %b [%idx , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
776803 return %r : f32
777804}
778805
779806// -----
780807
781- // CHECK-LABEL: fold_extract_broadcast
808+ // CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
782809// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
783810// CHECK: return %[[B]] : vector<4xf32>
784- func.func @fold_extract_broadcast (%a : f32 ) -> vector <4 xf32 > {
811+ func.func @canonicalize_extract_broadcast_to_higher_rank (%a : f32 ,
812+ %idx0 : index )
813+ -> vector <4 xf32 > {
785814 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
786- %r = vector.extract %b [0 , 1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
815+ // The indices don't batter for this canonicalizer, so we use mixed indices.
816+ %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
787817 return %r : vector <4 xf32 >
788818}
789819
790820// -----
791821
792- // CHECK-LABEL: fold_extract_broadcast
822+ // CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
793823// CHECK-SAME: %[[A:.*]]: vector<1xf32>
794824// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
795825// CHECK: return %[[R]] : vector<8xf32>
796- func.func @fold_extract_broadcast (%a : vector <1 xf32 >) -> vector <8 xf32 > {
826+ func.func @canonicalize_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >,
827+ %idx0 : index )
828+ -> vector <8 xf32 > {
797829 %b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
798- %r = vector.extract %b [0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
830+ // The indices don't batter for this canonicalizer, so we use mixed indices.
831+ %r = vector.extract %b [%idx0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
799832 return %r : vector <8 xf32 >
800833}
834+
801835// -----
802836
803837// CHECK-LABEL: @fold_extract_shuffle
0 commit comments