1515using namespace mlir ;
1616using namespace mlir ::memref;
1717
18+ //
19+ // Test the correctness of `memref::getNumContiguousTrailingDims`
20+ //
1821TEST (MemRefLayout, numContigDim) {
1922 MLIRContext ctx;
2023 OpBuilder b (&ctx);
@@ -25,79 +28,108 @@ TEST(MemRefLayout, numContigDim) {
2528 return StridedLayoutAttr::get (&ctx, 0 , s);
2629 };
2730
28- // memref<2x2x2xf32, strided<[4,2,1]>
31+ // Create a sequence of test cases, starting with the base case of a
32+ // contiguous 2x2x2 memref with fixed dimensions and then at each step
33+ // introducing one dynamic dimension starting from the right.
34+ // With thus obtained memref, start with maximally contiguous strides
35+ // and then at each step gradually introduce discontinuity by increasing
36+ // a fixed stride size from the left to right.
37+
38+ // In these and the following test cases the intent is to achieve code
39+ // coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
40+
41+ // memref<2x2x2xf32, strided<[4,2,1]>>
2942 auto m1 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
3043 EXPECT_EQ (m1.getNumContiguousTrailingDims (), 3 );
3144
32- // memref<2x2x2xf32, strided<[8,2,1]>
45+ // memref<2x2x2xf32, strided<[8,2,1]>>
3346 auto m2 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
3447 EXPECT_EQ (m2.getNumContiguousTrailingDims (), 2 );
3548
36- // memref<2x2x2xf32, strided<[8,4,1]>
49+ // memref<2x2x2xf32, strided<[8,4,1]>>
3750 auto m3 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
3851 EXPECT_EQ (m3.getNumContiguousTrailingDims (), 1 );
3952
40- // memref<2x2x2xf32, strided<[8,4,2]>
53+ // memref<2x2x2xf32, strided<[8,4,2]>>
4154 auto m4 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
4255 EXPECT_EQ (m4.getNumContiguousTrailingDims (), 0 );
4356
44- // memref<2x2x?xf32, strided<[?,?,1]>
57+ // memref<2x2x?xf32, strided<[?,?,1]>>
4558 auto m5 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 1 }));
4659 EXPECT_EQ (m5.getNumContiguousTrailingDims (), 1 );
4760
48- // memref<2x2x?xf32, strided<[?,?,2]>
61+ // memref<2x2x?xf32, strided<[?,?,2]>>
4962 auto m6 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 2 }));
5063 EXPECT_EQ (m6.getNumContiguousTrailingDims (), 0 );
5164
52- // memref<2x?x2xf32, strided<[?,2,1]>
65+ // memref<2x?x2xf32, strided<[?,2,1]>>
5366 auto m7 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
5467 EXPECT_EQ (m7.getNumContiguousTrailingDims (), 2 );
5568
56- // memref<2x?x2xf32, strided<[?,4,1]>
69+ // memref<2x?x2xf32, strided<[?,4,1]>>
5770 auto m8 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 1 }));
5871 EXPECT_EQ (m8.getNumContiguousTrailingDims (), 1 );
5972
60- // memref<2x?x2xf32, strided<[?,4,2]>
73+ // memref<2x?x2xf32, strided<[?,4,2]>>
6174 auto m9 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 2 }));
6275 EXPECT_EQ (m9.getNumContiguousTrailingDims (), 0 );
6376
64- // memref<?x2x2xf32, strided<[4,2,1]>
77+ // memref<?x2x2xf32, strided<[4,2,1]>>
6578 auto m10 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
6679 EXPECT_EQ (m10.getNumContiguousTrailingDims (), 3 );
6780
68- // memref<?x2x2xf32, strided<[8,2,1]>
81+ // memref<?x2x2xf32, strided<[8,2,1]>>
6982 auto m11 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
7083 EXPECT_EQ (m11.getNumContiguousTrailingDims (), 2 );
7184
72- // memref<?x2x2xf32, strided<[8,4,1]>
85+ // memref<?x2x2xf32, strided<[8,4,1]>>
7386 auto m12 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
7487 EXPECT_EQ (m12.getNumContiguousTrailingDims (), 1 );
7588
76- // memref<?x2x2xf32, strided<[8,4,2]>
89+ // memref<?x2x2xf32, strided<[8,4,2]>>
7790 auto m13 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
7891 EXPECT_EQ (m13.getNumContiguousTrailingDims (), 0 );
7992
80- // memref<2x2x1xf32, strided<[2,1,2]>
93+ //
94+ // Repeat a similar process, but this time introduce a unit memref dimension
95+ // to test that strides corresponding to unit dimensions are immaterial, even
96+ // if dynamic.
97+ //
98+
99+ // memref<2x2x1xf32, strided<[2,1,2]>>
81100 auto m14 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({2 , 1 , 2 }));
82101 EXPECT_EQ (m14.getNumContiguousTrailingDims (), 3 );
83102
84- // memref<2x2x1xf32, strided<[2,1,?]>
103+ // memref<2x2x1xf32, strided<[2,1,?]>>
85104 auto m15 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({2 , 1 , _}));
86105 EXPECT_EQ (m15.getNumContiguousTrailingDims (), 3 );
87106
88- // memref<2x2x1xf32, strided<[4,2,2]>
107+ // memref<2x2x1xf32, strided<[4,2,2]>>
89108 auto m16 = MemRefType::get ({2 , 2 , 1 }, f32 , strided ({4 , 2 , 2 }));
90109 EXPECT_EQ (m16.getNumContiguousTrailingDims (), 1 );
91110
92- // memref<2x1x2xf32, strided<[2,4,1]>
111+ // memref<2x1x2xf32, strided<[2,4,1]>>
93112 auto m17 = MemRefType::get ({2 , 1 , 2 }, f32 , strided ({2 , 4 , 1 }));
94113 EXPECT_EQ (m17.getNumContiguousTrailingDims (), 3 );
95114
96- // memref<2x1x2xf32, strided<[2,?,1]>
115+ // memref<2x1x2xf32, strided<[2,?,1]>>
97116 auto m18 = MemRefType::get ({2 , 1 , 2 }, f32 , strided ({2 , _, 1 }));
98117 EXPECT_EQ (m18.getNumContiguousTrailingDims (), 3 );
118+
119+ //
120+ // Special case for identity maps and no explicit `strided` attribute - the
121+ // memref is entirely contiguous even if the strides cannot be determined
122+ // statically.
123+ //
124+
125+ // memref<?x?x?xf32>
126+ auto m19 = MemRefType::get ({_, _, _}, f32 );
127+ EXPECT_EQ (m19.getNumContiguousTrailingDims (), 3 );
99128}
100129
130+ //
131+ // Test the member function `memref::areTrailingDimsContiguous`
132+ //
101133TEST (MemRefLayout, contigTrailingDim) {
102134 MLIRContext ctx;
103135 OpBuilder b (&ctx);
@@ -108,103 +140,18 @@ TEST(MemRefLayout, contigTrailingDim) {
108140 return StridedLayoutAttr::get (&ctx, 0 , s);
109141 };
110142
111- // memref<2x2x2xf32, strided<[4,2,1]>
112- auto m1 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
113- EXPECT_TRUE (m1.areTrailingDimsContiguous (1 ));
114- EXPECT_TRUE (m1.areTrailingDimsContiguous (2 ));
115- EXPECT_TRUE (m1.areTrailingDimsContiguous (3 ));
116-
117- // memref<2x2x2xf32, strided<[8,2,1]>
118- auto m2 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
119- EXPECT_TRUE (m2.areTrailingDimsContiguous (1 ));
120- EXPECT_TRUE (m2.areTrailingDimsContiguous (2 ));
121- EXPECT_FALSE (m2.areTrailingDimsContiguous (3 ));
122-
123- // memref<2x2x2xf32, strided<[8,4,1]>
124- auto m3 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
125- EXPECT_TRUE (m3.areTrailingDimsContiguous (1 ));
126- EXPECT_FALSE (m3.areTrailingDimsContiguous (2 ));
127- EXPECT_FALSE (m3.areTrailingDimsContiguous (3 ));
128-
129- // memref<2x2x2xf32, strided<[8,4,2]>
130- auto m4 = MemRefType::get ({2 , 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
131- EXPECT_FALSE (m4.areTrailingDimsContiguous (1 ));
132- EXPECT_FALSE (m4.areTrailingDimsContiguous (2 ));
133- EXPECT_FALSE (m4.areTrailingDimsContiguous (3 ));
134-
135- // memref<2x2x?xf32, strided<[?,?,1]>
136- auto m5 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 1 }));
137- EXPECT_TRUE (m5.areTrailingDimsContiguous (1 ));
138- EXPECT_FALSE (m5.areTrailingDimsContiguous (2 ));
139- EXPECT_FALSE (m5.areTrailingDimsContiguous (3 ));
140-
141- // memref<2x2x?xf32, strided<[?,?,2]>
142- auto m6 = MemRefType::get ({2 , 2 , _}, f32 , strided ({_, _, 2 }));
143- EXPECT_FALSE (m6.areTrailingDimsContiguous (1 ));
144- EXPECT_FALSE (m6.areTrailingDimsContiguous (2 ));
145- EXPECT_FALSE (m6.areTrailingDimsContiguous (3 ));
146-
147- // memref<2x?x2xf32, strided<[?,2,1]>
148- auto m7 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
149- EXPECT_TRUE (m7.areTrailingDimsContiguous (1 ));
150- EXPECT_TRUE (m7.areTrailingDimsContiguous (2 ));
151- EXPECT_FALSE (m7.areTrailingDimsContiguous (3 ));
152-
153- // memref<2x?x2xf32, strided<[?,4,1]>
154- auto m8 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 1 }));
155- EXPECT_TRUE (m8.areTrailingDimsContiguous (1 ));
156- EXPECT_FALSE (m8.areTrailingDimsContiguous (2 ));
157- EXPECT_FALSE (m8.areTrailingDimsContiguous (3 ));
158-
159- // memref<2x?x2xf32, strided<[?,4,2]>
160- auto m9 = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 4 , 2 }));
161- EXPECT_FALSE (m9.areTrailingDimsContiguous (1 ));
162- EXPECT_FALSE (m9.areTrailingDimsContiguous (2 ));
163- EXPECT_FALSE (m9.areTrailingDimsContiguous (3 ));
164-
165- // memref<?x2x2xf32, strided<[4,2,1]>
166- auto m10 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({4 , 2 , 1 }));
167- EXPECT_TRUE (m10.areTrailingDimsContiguous (1 ));
168- EXPECT_TRUE (m10.areTrailingDimsContiguous (2 ));
169- EXPECT_TRUE (m10.areTrailingDimsContiguous (3 ));
170-
171- // memref<?x2x2xf32, strided<[8,2,1]>
172- auto m11 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 2 , 1 }));
173- EXPECT_TRUE (m11.areTrailingDimsContiguous (1 ));
174- EXPECT_TRUE (m11.areTrailingDimsContiguous (2 ));
175- EXPECT_FALSE (m11.areTrailingDimsContiguous (3 ));
176-
177- // memref<?x2x2xf32, strided<[8,4,1]>
178- auto m12 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 1 }));
179- EXPECT_TRUE (m12.areTrailingDimsContiguous (1 ));
180- EXPECT_FALSE (m12.areTrailingDimsContiguous (2 ));
181- EXPECT_FALSE (m12.areTrailingDimsContiguous (3 ));
143+ // Pick up a random test case among the ones already present in the file and
144+ // ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
145+ // returned by `getNumContiguousTrailingDims` and `false` from that point on
146+ // up to the memref rank.
182147
183- // memref<?x2x2xf32, strided<[8,4,2]>
184- auto m13 = MemRefType::get ({_, 2 , 2 }, f32 , strided ({8 , 4 , 2 }));
185- EXPECT_FALSE (m13.areTrailingDimsContiguous (1 ));
186- EXPECT_FALSE (m13.areTrailingDimsContiguous (2 ));
187- EXPECT_FALSE (m13.areTrailingDimsContiguous (3 ));
188- }
189-
190- TEST (MemRefLayout, identityMaps) {
191- MLIRContext ctx;
192- OpBuilder b (&ctx);
148+ // memref<2x?x2xf32, strided<[?,2,1]>>
149+ auto m = MemRefType::get ({2 , _, 2 }, f32 , strided ({_, 2 , 1 }));
150+ int64_t n = m.getNumContiguousTrailingDims ();
151+ for (int64_t i = 0 ; i <= n; ++i)
152+ EXPECT_TRUE (m.areTrailingDimsContiguous (i));
193153
194- const int64_t _ = ShapedType::kDynamic ;
195- const FloatType f32 = b.getF32Type ();
196-
197- // memref<2x2x2xf32>
198- auto m1 = MemRefType::get ({2 , 2 , 2 }, f32 );
199- EXPECT_EQ (m1.getNumContiguousTrailingDims (), 3 );
200- EXPECT_TRUE (m1.areTrailingDimsContiguous (1 ));
201- EXPECT_TRUE (m1.areTrailingDimsContiguous (2 ));
202- EXPECT_TRUE (m1.areTrailingDimsContiguous (3 ));
203-
204- // memref<?x?x?xf32>
205- auto m2 = MemRefType::get ({_, _, _}, f32 );
206- EXPECT_EQ (m2.getNumContiguousTrailingDims (), 3 );
207- EXPECT_TRUE (m2.areTrailingDimsContiguous (1 ));
208- EXPECT_TRUE (m2.areTrailingDimsContiguous (2 ));
209- EXPECT_TRUE (m2.areTrailingDimsContiguous (3 ));
154+ int64_t r = m.getRank ();
155+ for (int64_t i = n + 1 ; i <= r; ++i)
156+ EXPECT_FALSE (m.areTrailingDimsContiguous (i));
210157}
0 commit comments