77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
10+ #include " mlir/IR/BuiltinTypeInterfaces.h"
1011#include " llvm/ADT/STLExtras.h"
1112#include " gtest/gtest.h"
1213#include < optional>
@@ -20,6 +21,29 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
2021 return std::optional<SmallVector<ReassociationIndices>>(list);
2122}
2223
24+ TEST (ReassociationIndicesForCollapse, ScalarTest) {
25+ EXPECT_EQ (getReassociationIndicesForCollapse ({1 }, {}),
26+ makeOptionalIndices ({{0 }}));
27+ EXPECT_EQ (getReassociationIndicesForCollapse ({1 , 1 }, {}),
28+ makeOptionalIndices ({{0 , 1 }}));
29+ EXPECT_EQ (getReassociationIndicesForCollapse ({ShapedType::kDynamic }, {}),
30+ makeOptionalIndices ({{0 }}));
31+ EXPECT_EQ (getReassociationIndicesForCollapse ({1 , ShapedType::kDynamic ,
32+ ShapedType::kDynamic , 1 ,
33+ ShapedType::kDynamic },
34+ {}),
35+ makeOptionalIndices ({{0 , 1 , 2 , 3 , 4 }}));
36+ }
37+
38+ TEST (ReassociationIndicesForCollapse, ScalarTestFailure) {
39+ EXPECT_EQ (getReassociationIndicesForCollapse ({}, {}), std::nullopt );
40+ EXPECT_EQ (getReassociationIndicesForCollapse ({}, {1 }), std::nullopt );
41+ EXPECT_EQ (getReassociationIndicesForCollapse ({2 }, {}), std::nullopt );
42+ EXPECT_EQ (
43+ getReassociationIndicesForCollapse ({1 , 2 , ShapedType::kDynamic , 1 }, {}),
44+ std::nullopt );
45+ }
46+
2347TEST (ReassociationIndicesForCollapse, StaticTest) {
2448 EXPECT_EQ (getReassociationIndicesForCollapse ({10 , 20 }, {200 }),
2549 makeOptionalIndices ({{0 , 1 }}));
0 commit comments