@@ -714,10 +714,9 @@ func.func @fold_extract_transpose(
714714// CHECK-SAME: %[[A:.*]]: f32
715715// CHECK: return %[[A]] : f32
716716func.func @fold_extract_broadcast_same_input_output_scalar (%a : f32 ,
717- %idx0 : index , idx1 : index ) -> f32 {
717+ %idx0 : index , idx1 : index , %idx2 : index ) -> f32 {
718718 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
719- // The indices don't matter for this folder, so we use mixed indices.
720- %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
719+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
721720 return %r : f32
722721}
723722
@@ -727,10 +726,9 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
727726// CHECK-SAME: %[[A:.*]]: vector<4xf32>
728727// CHECK: return %[[A]] : vector<4xf32>
729728func.func @fold_extract_broadcast_same_input_output_vec (%a : vector <4 xf32 >,
730- %idx0 : index ) -> vector <4 xf32 > {
729+ %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
731730 %b = vector.broadcast %a : vector <4 xf32 > to vector <1 x2 x4 xf32 >
732- // The indices don't matter for this folder, so we use mixed indices.
733- %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
731+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
734732 return %r : vector <4 xf32 >
735733}
736734
@@ -741,10 +739,9 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
741739// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
742740// CHECK: return %[[B]] : f32
743741func.func @fold_extract_broadcast_0dvec_input_scalar_output (%a : vector <f32 >,
744- %idx0 : index , idx1 : index ) -> f32 {
742+ %idx0 : index , % idx1 : index , %idx2 : index ) -> f32 {
745743 %b = vector.broadcast %a : vector <f32 > to vector <1 x2 x4 xf32 >
746- // The indices don't matter for this folder, so we use mixed indices.
747- %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
744+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
748745 return %r : f32
749746}
750747
@@ -756,9 +753,8 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
756753// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
757754// CHECK: return %[[B]] : vector<4xf32>
758755func.func @fold_extract_broadcast_diff_input_output_vec (%a : vector <2 x4 xf32 >,
759- %idx0 : index , idx1 : index ) -> vector <4 xf32 > {
756+ %idx0 : index , % idx1 : index ) -> vector <4 xf32 > {
760757 %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x2 x4 xf32 >
761- // The indices don't matter for this folder, so we use mixed indices.
762758 %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
763759 return %r : vector <4 xf32 >
764760}
@@ -779,10 +775,9 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
779775// CHECK-LABEL: fold_extract_splat
780776// CHECK-SAME: %[[A:.*]]: f32
781777// CHECK: return %[[A]] : f32
782- func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index ) -> f32 {
778+ func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
783779 %b = vector.splat %a : vector <1 x2 x4 xf32 >
784- // The indices don't matter for this folder, so we use mixed indices.
785- %r = vector.extract %b [%idx0 , %idx1 , 2 ] : f32 from vector <1 x2 x4 xf32 >
780+ %r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
786781 return %r : f32
787782}
788783
@@ -791,12 +786,11 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
791786// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
792787// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
793788// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
794- // CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0 ] : f32 from vector<2x1xf32>
789+ // CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], %[[IDX2]] ] : f32 from vector<2x1xf32>
795790// CHECK: return %[[R]] : f32
796791func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <2 x1 xf32 >,
797- %idx : index , idx1 : index , idx2 : index ) -> f32 {
792+ %idx : index , % idx1 : index , % idx2 : index ) -> f32 {
798793 %b = vector.broadcast %a : vector <2 x1 xf32 > to vector <1 x2 x4 xf32 >
799- // The indices don't matter for this folder, so we use mixed indices.
800794 %r = vector.extract %b [%idx , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
801795 return %r : f32
802796}
@@ -806,11 +800,10 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
806800// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
807801// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
808802// CHECK: return %[[B]] : vector<4xf32>
809- func.func @fold_extract_broadcast_to_higher_rank (%a : f32 , idx0 : index )
803+ func.func @fold_extract_broadcast_to_higher_rank (%a : f32 , % idx0 : index , %idx1 : index )
810804 -> vector <4 xf32 > {
811805 %b = vector.broadcast %a : f32 to vector <1 x2 x4 xf32 >
812- // The indices don't matter for this canonicalizer, so we use mixed indices.
813- %r = vector.extract %b [0 , %idx0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
806+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
814807 return %r : vector <4 xf32 >
815808}
816809
@@ -820,10 +813,9 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, idx0 : index)
820813// CHECK-SAME: %[[A:.*]]: vector<1xf32>
821814// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
822815// CHECK: return %[[R]] : vector<8xf32>
823- func.func @fold_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >, idx0 : index )
816+ func.func @fold_extract_broadcast_to_equal_rank (%a : vector <1 xf32 >, % idx0 : index )
824817 -> vector <8 xf32 > {
825818 %b = vector.broadcast %a : vector <1 xf32 > to vector <1 x8 xf32 >
826- // The indices don't matter for this canonicalizer, so we use mixed indices.
827819 %r = vector.extract %b [%idx0 ] : vector <8 xf32 > from vector <1 x8 xf32 >
828820 return %r : vector <8 xf32 >
829821}
0 commit comments