@@ -28,103 +28,60 @@ TEST(MemRefLayout, numContigDim) {
2828 return StridedLayoutAttr::get (&ctx, 0 , s);
2929 };
3030
31- // Create a sequence of test cases, starting with the base case of a
32- // contiguous 2x2x2 memref with fixed dimensions and then at each step
33- // introducing one dynamic dimension starting from the right.
34- // With thus obtained memref, start with maximally contiguous strides
35- // and then at each step gradually introduce discontinuity by increasing
36- // a fixed stride size from the left to right.
37-
38- // In these and the following test cases the intent is to achieve code
39- // coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
40-
41- // memref<2x2x2xf32, strided<[4,2,1]>>
42- auto m1 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
43- EXPECT_EQ (m1.getNumContiguousTrailingDims (), 3 );
44-
45- // memref<2x2x2xf32, strided<[8,2,1]>>
46- auto m2 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
47- EXPECT_EQ (m2.getNumContiguousTrailingDims (), 2 );
48-
49- // memref<2x2x2xf32, strided<[8,4,1]>>
50- auto m3 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
51- EXPECT_EQ (m3.getNumContiguousTrailingDims (), 1 );
31+ // Special case for identity maps and no explicit `strided` attribute - the
32+ // memref is entirely contiguous even if the strides cannot be determined
33+ // statically.
5234
53- // memref<2x2x2xf32, strided<[8,4,2]> >
54- auto m4 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({ 8 , 4 , 2 }) );
55- EXPECT_EQ (m4 .getNumContiguousTrailingDims (), 0 );
35+ // memref<?x?x?xf32 >
36+ auto m0 = MemRefType::get ({_, _, _ }, f32 );
37+ EXPECT_EQ (m0 .getNumContiguousTrailingDims (), 3 );
5638
57- // memref<2x2x?xf32, strided<[?,?,1]>>
58- auto m5 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 1 }));
59- EXPECT_EQ (m5.getNumContiguousTrailingDims (), 1 );
39+ // Conservatively assume memref is sparse everywhere if cannot get the
40+ // strides.
6041
61- // memref<2x2x?xf32, strided<[?,?,2]>>
62- auto m6 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 2 }));
63- EXPECT_EQ (m6.getNumContiguousTrailingDims (), 0 );
42+ // memref<2x2x2xf32, (i,j,k)->(i,k,j)>
43+ auto m1 = MemRefType::get (
44+ {2 , 2 , 2 }, f32 ,
45+ AffineMap::getPermutationMap (ArrayRef<int64_t >{0 , 2 , 1 }, &ctx));
46+ EXPECT_EQ (m1.getNumContiguousTrailingDims (), 0 );
6447
65- // memref<2x?x2xf32, strided<[?,2,1]>>
66- auto m7 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
67- EXPECT_EQ (m7.getNumContiguousTrailingDims (), 2 );
48+ // A base cases of a fixed memref with the usual strides.
6849
69- // memref<2x?x2xf32 , strided<[?,4, 1]>>
70- auto m8 = MemRefType::get ({2 , _ , 2 }, f32 , strided ({_, 4 , 1 }));
71- EXPECT_EQ (m8 .getNumContiguousTrailingDims (), 1 );
50+ // memref<2x2x2xf32 , strided<[4, 2, 1]>>
51+ auto m3 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
52+ EXPECT_EQ (m3 .getNumContiguousTrailingDims (), 3 );
7253
73- // memref<2x?x2xf32, strided<[?,4,2]>>
74- auto m9 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 2 }));
75- EXPECT_EQ (m9.getNumContiguousTrailingDims (), 0 );
54+ // A fixed memref with a discontinuity in the rightmost dimension.
7655
77- // memref<?x2x2xf32, strided<[4,2,1]>>
78- auto m10 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
79- EXPECT_EQ (m10.getNumContiguousTrailingDims (), 3 );
80-
81- // memref<?x2x2xf32, strided<[8,2,1]>>
82- auto m11 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
83- EXPECT_EQ (m11.getNumContiguousTrailingDims (), 2 );
56+ // memref<2x2x2xf32, strided<[8, 4, 2]>>
57+ auto m4 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
58+ EXPECT_EQ (m4.getNumContiguousTrailingDims (), 0 );
8459
85- // memref<?x2x2xf32, strided<[8,4,1]>>
86- auto m12 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
87- EXPECT_EQ (m12.getNumContiguousTrailingDims (), 1 );
60+ // A fixed memref with a discontinuity in the "middle".
8861
89- // memref<?x2x2xf32 , strided<[8,4,2 ]>>
90- auto m13 = MemRefType::get ({_ , 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
91- EXPECT_EQ (m13 .getNumContiguousTrailingDims (), 0 );
62+ // memref<2x2x2xf32 , strided<[8, 2, 1 ]>>
63+ auto m5 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
64+ EXPECT_EQ (m5 .getNumContiguousTrailingDims (), 2 );
9265
93- //
94- // Repeat a similar process, but this time introduce a unit memref dimension
95- // to test that strides corresponding to unit dimensions are immaterial, even
96- // if dynamic.
97- //
66+ // A dynamic memref where the dynamic dimension breaks continuity.
9867
99- // memref<2x2x1xf32 , strided<[2,1,2 ]>>
100- auto m14 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({2 , 1 , 2 }));
101- EXPECT_EQ (m14 .getNumContiguousTrailingDims (), 3 );
68+ // memref<2x?x2xf32 , strided<[4, 2, 1 ]>>
69+ auto m6 = MemRefType::get ({2 , _, 2 }, f32 , strided ({4 , 2 , 1 }));
70+ EXPECT_EQ (m6 .getNumContiguousTrailingDims (), 2 );
10271
103- // memref<2x2x1xf32, strided<[2,1,?]>>
104- auto m15 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({2 , 1 , _}));
105- EXPECT_EQ (m15.getNumContiguousTrailingDims (), 3 );
72+ // A edge case of a dynamic memref where the dynamic dimension is the first
73+ // one.
10674
107- // memref<2x2x1xf32 , strided<[4,2,2 ]>>
108- auto m16 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({4 , 2 , 2 }));
109- EXPECT_EQ (m16 .getNumContiguousTrailingDims (), 1 );
75+ // memref<?x2x2xf32 , strided<[4, 2, 1 ]>>
76+ auto m7 = MemRefType::get ({2 , _, 2 }, f32 , strided ({4 , 2 , 1 }));
77+ EXPECT_EQ (m7 .getNumContiguousTrailingDims (), 2 );
11078
111- // memref<2x1x2xf32, strided<[2,4,1]>>
112- auto m17 = MemRefType::get ({2 , 1 , 2 }, f32 , strided ({2 , 4 , 1 }));
113- EXPECT_EQ (m17.getNumContiguousTrailingDims (), 3 );
79+ // A memref with a unit dimension. Unit dimensions do not affect continuity,
80+ // even if the corresponding stride is dynamic.
11481
11582 // memref<2x1x2xf32, strided<[2,?,1]>>
116- auto m18 = MemRefType::get ({2 , 1 , 2 }, f32 , strided ({2 , _, 1 }));
117- EXPECT_EQ (m18.getNumContiguousTrailingDims (), 3 );
118-
119- //
120- // Special case for identity maps and no explicit `strided` attribute - the
121- // memref is entirely contiguous even if the strides cannot be determined
122- // statically.
123- //
124-
125- // memref<?x?x?xf32>
126- auto m19 = MemRefType::get ({_, _, _}, f32 );
127- EXPECT_EQ (m19.getNumContiguousTrailingDims (), 3 );
83+ auto m8 = MemRefType::get ({2 , 1 , 2 }, f32 , strided ({2 , _, 1 }));
84+ EXPECT_EQ (m8.getNumContiguousTrailingDims (), 3 );
12885}
12986
13087//
@@ -140,18 +97,15 @@ TEST(MemRefLayout, contigTrailingDim) {
14097 return StridedLayoutAttr::get (&ctx, 0 , s);
14198 };
14299
143- // Pick up a random test case among the ones already present in the file and
144- // ensure `areTrailingDimsContiguous(k) ` returns `true` up to the value
145- // returned by `getNumContiguousTrailingDims` and `false` from that point on
146- // up to the memref rank .
100+ // A not-entirely-continuous, not-entirely-discontinuous memref.
101+ // ensure `areTrailingDimsContiguous` returns `true` for the value
102+ // returned by `getNumContiguousTrailingDims` and `false` for the next bigger
103+ // number .
147104
148105 // memref<2x?x2xf32, strided<[?,2,1]>>
149106 auto m = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
150107 int64_t n = m.getNumContiguousTrailingDims ();
151- for (int64_t i = 0 ; i <= n; ++i)
152- EXPECT_TRUE (m.areTrailingDimsContiguous (i));
153-
154- int64_t r = m.getRank ();
155- for (int64_t i = n + 1 ; i <= r; ++i)
156- EXPECT_FALSE (m.areTrailingDimsContiguous (i));
108+ EXPECT_TRUE (m.areTrailingDimsContiguous (n));
109+ ASSERT_TRUE (n + 1 <= m.getRank ());
110+ EXPECT_FALSE (m.areTrailingDimsContiguous (n + 1 ));
157111}
0 commit comments