Skip to content

Commit b5425a0

Browse files
[fixup] Reduce the number of test cases
1 parent 145b055 commit b5425a0

File tree

1 file changed

+45
-91
lines changed

1 file changed

+45
-91
lines changed

mlir/unittests/IR/MemrefLayoutTest.cpp

Lines changed: 45 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -28,103 +28,60 @@ TEST(MemRefLayout, numContigDim) {
2828
return StridedLayoutAttr::get(&ctx, 0, s);
2929
};
3030

31-
// Create a sequence of test cases, starting with the base case of a
32-
// contiguous 2x2x2 memref with fixed dimensions and then at each step
33-
// introducing one dynamic dimension starting from the right.
34-
// With thus obtained memref, start with maximally contiguous strides
35-
// and then at each step gradually introduce discontinuity by increasing
36-
// a fixed stride size from the left to right.
37-
38-
// In these and the following test cases the intent is to achieve code
39-
// coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
40-
41-
// memref<2x2x2xf32, strided<[4,2,1]>>
42-
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
43-
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
44-
45-
// memref<2x2x2xf32, strided<[8,2,1]>>
46-
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
47-
EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
48-
49-
// memref<2x2x2xf32, strided<[8,4,1]>>
50-
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
51-
EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
31+
// Special case for identity maps and no explicit `strided` attribute - the
32+
// memref is entirely contiguous even if the strides cannot be determined
33+
// statically.
5234

53-
// memref<2x2x2xf32, strided<[8,4,2]>>
54-
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
55-
EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
35+
// memref<?x?x?xf32>
36+
auto m0 = MemRefType::get({_, _, _}, f32);
37+
EXPECT_EQ(m0.getNumContiguousTrailingDims(), 3);
5638

57-
// memref<2x2x?xf32, strided<[?,?,1]>>
58-
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
59-
EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
39+
// Conservatively assume memref is sparse everywhere if cannot get the
40+
// strides.
6041

61-
// memref<2x2x?xf32, strided<[?,?,2]>>
62-
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
63-
EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
42+
// memref<2x2x2xf32, (i,j,k)->(i,k,j)>
43+
auto m1 = MemRefType::get(
44+
{2, 2, 2}, f32,
45+
AffineMap::getPermutationMap(ArrayRef<int64_t>{0, 2, 1}, &ctx));
46+
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 0);
6447

65-
// memref<2x?x2xf32, strided<[?,2,1]>>
66-
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
67-
EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
48+
// A base cases of a fixed memref with the usual strides.
6849

69-
// memref<2x?x2xf32, strided<[?,4,1]>>
70-
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
71-
EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
50+
// memref<2x2x2xf32, strided<[4, 2, 1]>>
51+
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
52+
EXPECT_EQ(m3.getNumContiguousTrailingDims(), 3);
7253

73-
// memref<2x?x2xf32, strided<[?,4,2]>>
74-
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
75-
EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
54+
// A fixed memref with a discontinuity in the rightmost dimension.
7655

77-
// memref<?x2x2xf32, strided<[4,2,1]>>
78-
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
79-
EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
80-
81-
// memref<?x2x2xf32, strided<[8,2,1]>>
82-
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
83-
EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
56+
// memref<2x2x2xf32, strided<[8, 4, 2]>>
57+
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
58+
EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
8459

85-
// memref<?x2x2xf32, strided<[8,4,1]>>
86-
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
87-
EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
60+
// A fixed memref with a discontinuity in the "middle".
8861

89-
// memref<?x2x2xf32, strided<[8,4,2]>>
90-
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
91-
EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
62+
// memref<2x2x2xf32, strided<[8, 2, 1]>>
63+
auto m5 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
64+
EXPECT_EQ(m5.getNumContiguousTrailingDims(), 2);
9265

93-
//
94-
// Repeat a similar process, but this time introduce a unit memref dimension
95-
// to test that strides corresponding to unit dimensions are immaterial, even
96-
// if dynamic.
97-
//
66+
// A dynamic memref where the dynamic dimension breaks continuity.
9867

99-
// memref<2x2x1xf32, strided<[2,1,2]>>
100-
auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
101-
EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
68+
// memref<2x?x2xf32, strided<[4, 2, 1]>>
69+
auto m6 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
70+
EXPECT_EQ(m6.getNumContiguousTrailingDims(), 2);
10271

103-
// memref<2x2x1xf32, strided<[2,1,?]>>
104-
auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
105-
EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
72+
// A edge case of a dynamic memref where the dynamic dimension is the first
73+
// one.
10674

107-
// memref<2x2x1xf32, strided<[4,2,2]>>
108-
auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
109-
EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
75+
// memref<?x2x2xf32, strided<[4, 2, 1]>>
76+
auto m7 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
77+
EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
11078

111-
// memref<2x1x2xf32, strided<[2,4,1]>>
112-
auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
113-
EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
79+
// A memref with a unit dimension. Unit dimensions do not affect continuity,
80+
// even if the corresponding stride is dynamic.
11481

11582
// memref<2x1x2xf32, strided<[2,?,1]>>
116-
auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
117-
EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
118-
119-
//
120-
// Special case for identity maps and no explicit `strided` attribute - the
121-
// memref is entirely contiguous even if the strides cannot be determined
122-
// statically.
123-
//
124-
125-
// memref<?x?x?xf32>
126-
auto m19 = MemRefType::get({_, _, _}, f32);
127-
EXPECT_EQ(m19.getNumContiguousTrailingDims(), 3);
83+
auto m8 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
84+
EXPECT_EQ(m8.getNumContiguousTrailingDims(), 3);
12885
}
12986

13087
//
@@ -140,18 +97,15 @@ TEST(MemRefLayout, contigTrailingDim) {
14097
return StridedLayoutAttr::get(&ctx, 0, s);
14198
};
14299

143-
// Pick up a random test case among the ones already present in the file and
144-
// ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
145-
// returned by `getNumContiguousTrailingDims` and `false` from that point on
146-
// up to the memref rank.
100+
// A not-entirely-continuous, not-entirely-discontinuous memref.
101+
// ensure `areTrailingDimsContiguous` returns `true` for the value
102+
// returned by `getNumContiguousTrailingDims` and `false` for the next bigger
103+
// number.
147104

148105
// memref<2x?x2xf32, strided<[?,2,1]>>
149106
auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
150107
int64_t n = m.getNumContiguousTrailingDims();
151-
for (int64_t i = 0; i <= n; ++i)
152-
EXPECT_TRUE(m.areTrailingDimsContiguous(i));
153-
154-
int64_t r = m.getRank();
155-
for (int64_t i = n + 1; i <= r; ++i)
156-
EXPECT_FALSE(m.areTrailingDimsContiguous(i));
108+
EXPECT_TRUE(m.areTrailingDimsContiguous(n));
109+
ASSERT_TRUE(n + 1 <= m.getRank());
110+
EXPECT_FALSE(m.areTrailingDimsContiguous(n + 1));
157111
}

0 commit comments

Comments
 (0)