Skip to content

Commit 7900062

Browse files
committed
fix types and use dynamic indices
1 parent e2a4e6d commit 7900062

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,9 @@ func.func @fold_extract_transpose(
714714
// CHECK-SAME: %[[A:.*]]: f32
715715
// CHECK: return %[[A]] : f32
716716
func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
717-
%idx0 : index, idx1 : index) -> f32 {
717+
%idx0 : index, idx1 : index, %idx2 : index) -> f32 {
718718
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
719-
// The indices don't matter for this folder, so we use mixed indices.
720-
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
719+
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
721720
return %r : f32
722721
}
723722

@@ -727,10 +726,9 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
727726
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
728727
// CHECK: return %[[A]] : vector<4xf32>
729728
func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
730-
%idx0 : index) -> vector<4xf32> {
729+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
731730
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
732-
// The indices don't matter for this folder, so we use mixed indices.
733-
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
731+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
734732
return %r : vector<4xf32>
735733
}
736734

@@ -741,10 +739,9 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
741739
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
742740
// CHECK: return %[[B]] : f32
743741
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
744-
%idx0 : index, idx1 : index) -> f32 {
742+
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
745743
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
746-
// The indices don't matter for this folder, so we use mixed indices.
747-
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
744+
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
748745
return %r : f32
749746
}
750747

@@ -756,9 +753,8 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
756753
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
757754
// CHECK: return %[[B]] : vector<4xf32>
758755
func.func @fold_extract_broadcast_diff_input_output_vec(%a : vector<2x4xf32>,
759-
%idx0 : index, idx1 : index) -> vector<4xf32> {
756+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
760757
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
761-
// The indices don't matter for this folder, so we use mixed indices.
762758
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
763759
return %r : vector<4xf32>
764760
}
@@ -779,10 +775,9 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
779775
// CHECK-LABEL: fold_extract_splat
780776
// CHECK-SAME: %[[A:.*]]: f32
781777
// CHECK: return %[[A]] : f32
782-
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
778+
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
783779
%b = vector.splat %a : vector<1x2x4xf32>
784-
// The indices don't matter for this folder, so we use mixed indices.
785-
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
780+
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
786781
return %r : f32
787782
}
788783

@@ -791,12 +786,11 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
791786
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
792787
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
793788
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
794-
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
789+
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], %[[IDX2]]] : f32 from vector<2x1xf32>
795790
// CHECK: return %[[R]] : f32
796791
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
797-
%idx : index, idx1 : index, idx2 : index) -> f32 {
792+
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
798793
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
799-
// The indices don't matter for this folder, so we use mixed indices.
800794
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
801795
return %r : f32
802796
}
@@ -806,11 +800,10 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
806800
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
807801
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
808802
// CHECK: return %[[B]] : vector<4xf32>
809-
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, idx0 : index)
803+
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
810804
-> vector<4xf32> {
811805
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
812-
// The indices don't matter for this canonicalizer, so we use mixed indices.
813-
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
806+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
814807
return %r : vector<4xf32>
815808
}
816809

@@ -820,10 +813,9 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, idx0 : index)
820813
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
821814
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
822815
// CHECK: return %[[R]] : vector<8xf32>
823-
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, idx0 : index)
816+
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
824817
-> vector<8xf32> {
825818
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
826-
// The indices don't matter for this canonicalizer, so we use mixed indices.
827819
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
828820
return %r : vector<8xf32>
829821
}

0 commit comments

Comments
 (0)