Skip to content

Commit 350f449

Browse files
committed
resolve comment
1 parent 110ccdc commit 350f449

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -747,20 +747,6 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
747747

748748
// -----
749749

750-
// CHECK-LABEL: fold_extract_broadcast_diff_input_output_vec
751-
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
752-
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
753-
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
754-
// CHECK: return %[[B]] : vector<4xf32>
755-
func.func @fold_extract_broadcast_diff_input_output_vec(%a : vector<2x4xf32>,
756-
%idx0 : index, %idx1 : index) -> vector<4xf32> {
757-
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
758-
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
759-
return %r : vector<4xf32>
760-
}
761-
762-
// -----
763-
764750
// CHECK-LABEL: fold_extract_broadcast_negative
765751
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
766752
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -797,9 +783,25 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
797783

798784
// -----
799785

786+
// CHECK-LABEL: fold_extract_broadcast_to_lower_rank
787+
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
788+
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
789+
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
790+
// CHECK: return %[[B]] : vector<4xf32>
791+
// rank(extract_output) < rank(broadcast_input)
792+
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
793+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
794+
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
795+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
796+
return %r : vector<4xf32>
797+
}
798+
799+
// -----
800+
800801
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
801802
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
802803
// CHECK: return %[[B]] : vector<4xf32>
804+
// rank(extract_output) > rank(broadcast_input)
803805
func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
804806
-> vector<4xf32> {
805807
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
@@ -813,6 +815,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
813815
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
814816
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
815817
// CHECK: return %[[R]] : vector<8xf32>
818+
// rank(extract_output) == rank(broadcast_input)
816819
func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
817820
-> vector<8xf32> {
818821
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>

0 commit comments

Comments
 (0)