@@ -652,24 +652,44 @@ func.func @fold_extract_transpose(
652652
653653// -----
654654
655- // CHECK-LABEL: fold_extract_broadcast
655+ // CHECK-LABEL: fold_extract_broadcast_same_type
656656// CHECK-SAME: %[[A:.*]]: f32
657657// CHECK: return %[[A]] : f32
658- func.func @fold_extract_broadcast (%a : f32 ) -> f32 {
658+ func.func @fold_extract_broadcast_same_type (%a : f32 ,
659+ %idx0 : index ,
660+ %idx1 : index ) -> f32 {
659661 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
660- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
662+ // The indices don't batter for this folder, so we use mixed indices.
663+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
661664 return %r : f32
662665}
663666
664667// -----
665668
666- // CHECK-LABEL: fold_extract_broadcast_0dvec
669+ // CHECK-LABEL: fold_extract_broadcast_same_type_vec
670+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>
671+ // CHECK: return %[[A]] : vector<4xf32>
672+ func.func @fold_extract_broadcast_same_type_vec (%a : vector <4 xf32 >,
673+ %idx0 : index )
674+ -> vector <4 xf32 > {
675+ %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
676+ // The indices don't batter for this folder, so we use mixed indices.
677+ %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
678+ return %r : vector <4 xf32 >
679+ }
680+
681+ // -----
682+
683+ // CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
667684// CHECK-SAME: %[[A:.*]]: vector<f32>
668685// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
669686// CHECK: return %[[B]] : f32
670- func.func @fold_extract_broadcast_0dvec (%a : vector <f32 >) -> f32 {
687+ func.func @fold_extract_broadcast_0dvec_and_scalar (%a : vector <f32 >,
688+ %idx0 : index ,
689+ %idx1 : index ) -> f32 {
671690 %b = vector.broadcast %a : vector <f32 > to vector <1 x2 x4 xf32 >
672- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
691+ // The indices don't batter for this folder, so we use mixed indices.
692+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
673693 return %r : f32
674694}
675695
@@ -689,57 +709,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
689709// CHECK-LABEL: fold_extract_splat
690710// CHECK-SAME: %[[A:.*]]: f32
691711// CHECK: return %[[A]] : f32
692- func.func @fold_extract_splat (%a : f32 ) -> f32 {
712+ func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index ) -> f32 {
693713 %b = vector.splat %a : vector <1 x2 x4 xf32 >
694- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
714+ // The indices don't batter for this folder, so we use mixed indices.
715+ %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
695716 return %r : f32
696717}
697718
698719// -----
699720
700- // CHECK-LABEL: fold_extract_broadcast_vector
721+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
701722// CHECK-SAME: %[[A:.*]]: vector<4xf32>
702- // CHECK: return %[[A]] : vector<4xf32>
703- func.func @fold_extract_broadcast_vector (%a : vector <4 xf32 >) -> vector <4 xf32 > {
723+ // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
724+ // CHECK: return %[[R]] : f32
725+ func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <4 xf32 >) -> f32 {
704726 %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
705- %r = vector.extract %b [0 , 1 ] : vector < 4 x f32 > from vector <1 x2 x4 xf32 >
706- return %r : vector < 4 x f32 >
727+ %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
728+ return %r : f32
707729}
708730
709731// -----
710732
711- // CHECK-LABEL: fold_extract_broadcast
733+ // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
712734// CHECK-SAME: %[[A:.*]]: vector<4xf32>
713- // CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
735+ // CHECK-SAME: %[[IDX:.*]]: index
736+ // CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
737+ // CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
714738// CHECK: return %[[R]] : f32
715- func.func @fold_extract_broadcast (%a : vector <4 xf32 >) -> f32 {
739+ // This folder is not yet implemented. Check that this does not fold.
740+ func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi (
741+ %a : vector <4 xf32 >,
742+ %idx : index ) -> f32 {
716743 %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
717- %r = vector.extract %b [0 , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
744+ %r = vector.extract %b [%idx , 1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
718745 return %r : f32
719746}
720747
721748// -----
722749
723- // CHECK-LABEL: fold_extract_broadcast
750+ // CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
724751// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
725752// CHECK: return %[[B]] : vector<4xf32>
726- func.func @fold_extract_broadcast (%a : f32 ) -> vector <4 xf32 > {
753+ func.func @canonicalize_extract_broadcast_to_higher_rank (%a : f32 ,
754+ %idx0 : index )
755+ -> vector <4 xf32 > {
727756 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
728- %r = vector.extract %b [0 , 1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
757+ // The indices don't batter for this canonicalizer, so we use mixed indices.
758+ %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
729759 return %r : vector <4 xf32 >
730760}
731761
732762// -----
733763
734- // CHECK-LABEL: fold_extract_broadcast
764+ // CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
735765// CHECK-SAME: %[[A:.*]]: vector<1xf32>
736766// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
737767// CHECK: return %[[R]] : vector<8xf32>
738- func.func @fold_extract_broadcast (%a : vector <1 xf32 >) -> vector <8 xf32 > {
768+ func.func @canonicalize_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >,
769+ %idx0 : index )
770+ -> vector <8 xf32 > {
739771 %b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
740- %r = vector.extract %b [0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
772+ // The indices don't batter for this canonicalizer, so we use mixed indices.
773+ %r = vector.extract %b [%idx0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
741774 return %r : vector <8 xf32 >
742775}
776+
743777// -----
744778
745779// CHECK-LABEL: @fold_extract_shuffle
0 commit comments