@@ -137,3 +137,113 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
137137 return %1 : vector <3 x3 x3 xi8 >
138138}
139139
140+
141+ // -----
142+
143+ // Test of FoldTransposeShapeCast
144+ // In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
145+ // 1 -> 0
146+ // 2 -> 4
147+ // Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
148+ // CHECK-LABEL: @transpose_shape_cast
149+ // CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
150+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
151+ // CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
152+ // CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
153+ func.func @transpose_shape_cast (%arg : vector <1 x4 x4 x1 x1 xi8 >) -> vector <4 x4 xi8 > {
154+ %0 = vector.transpose %arg , [1 , 0 , 3 , 4 , 2 ]
155+ : vector <1 x4 x4 x1 x1 xi8 > to vector <4 x1 x1 x1 x4 xi8 >
156+ %1 = vector.shape_cast %0 : vector <4 x1 x1 x1 x4 xi8 > to vector <4 x4 xi8 >
157+ return %1 : vector <4 x4 xi8 >
158+ }
159+
160+ // -----
161+
162+ // Test of FoldTransposeShapeCast
163+ // In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
164+ // 1 -> 2
165+ // 2 -> 1
166+ // As this is not increasing (2 > 1), this transpose is not order
167+ // preserving and cannot be treated as a shape_cast.
168+ // CHECK-LABEL: @negative_transpose_shape_cast
169+ // CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
170+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
171+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
172+ // CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
173+ func.func @negative_transpose_shape_cast (%arg : vector <1 x4 x4 x1 xi8 >) -> vector <4 x4 xi8 > {
174+ %0 = vector.transpose %arg , [0 , 2 , 1 , 3 ]
175+ : vector <1 x4 x4 x1 xi8 > to vector <1 x4 x4 x1 xi8 >
176+ %1 = vector.shape_cast %0 : vector <1 x4 x4 x1 xi8 > to vector <4 x4 xi8 >
177+ return %1 : vector <4 x4 xi8 >
178+ }
179+
180+ // -----
181+
182+ // Test of FoldTransposeShapeCast
183+ // Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
184+ // scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
185+ // CHECK-LABEL: @negative_transpose_shape_cast_scalable
186+ // CHECK: vector.transpose
187+ // CHECK: vector.shape_cast
188+ func.func @negative_transpose_shape_cast_scalable (%arg : vector <[4 ]x1 xi8 >) -> vector <[4 ]xi8 > {
189+ %0 = vector.transpose %arg , [1 , 0 ] : vector <[4 ]x1 xi8 > to vector <1 x[4 ]xi8 >
190+ %1 = vector.shape_cast %0 : vector <1 x[4 ]xi8 > to vector <[4 ]xi8 >
191+ return %1 : vector <[4 ]xi8 >
192+ }
193+
194+ // -----
195+
196+ // Test of shape_cast folding.
197+ // The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
198+ // vectors.
199+ // CHECK-LABEL: @shape_cast_transpose_scalable
200+ // CHECK: vector.shape_cast
201+ // CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
202+ func.func @shape_cast_transpose_scalable (%arg : vector <[4 ]xi8 >) -> vector <[4 ]x1 xi8 > {
203+ %0 = vector.shape_cast %arg : vector <[4 ]xi8 > to vector <1 x[4 ]xi8 >
204+ %1 = vector.transpose %0 , [1 , 0 ] : vector <1 x[4 ]xi8 > to vector <[4 ]x1 xi8 >
205+ return %1 : vector <[4 ]x1 xi8 >
206+ }
207+
208+ // -----
209+
210+ // Test of shape_cast folding.
211+ // A transpose that is 'order preserving' can be treated like a shape_cast.
212+ // CHECK-LABEL: @shape_cast_transpose
213+ // CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
214+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
215+ // CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
216+ // CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
217+ func.func @shape_cast_transpose (%arg : vector <2 x3 x1 x1 xi8 >) -> vector <6 x1 x1 xi8 > {
218+ %0 = vector.shape_cast %arg : vector <2 x3 x1 x1 xi8 > to vector <6 x1 x1 xi8 >
219+ %1 = vector.transpose %0 , [0 , 2 , 1 ]
220+ : vector <6 x1 x1 xi8 > to vector <6 x1 x1 xi8 >
221+ return %1 : vector <6 x1 x1 xi8 >
222+ }
223+
224+ // -----
225+
226+ // Test of shape_cast folding.
227+ // Scalable dimensions should be treated as non-unit dimensions.
228+ // CHECK-LABEL: @shape_cast_transpose_scalable
229+ // CHECK: vector.shape_cast
230+ // CHECK: vector.transpose
231+ func.func @shape_cast_transpose_scalable_unit (%arg : vector <[1 ]x4 x1 xi8 >) -> vector <4 x[1 ]xi8 > {
232+ %0 = vector.shape_cast %arg : vector <[1 ]x4 x1 xi8 > to vector <[1 ]x4 xi8 >
233+ %1 = vector.transpose %0 , [1 , 0 ] : vector <[1 ]x4 xi8 > to vector <4 x[1 ]xi8 >
234+ return %1 : vector <4 x[1 ]xi8 >
235+ }
236+
237+ // -----
238+
239+ // Test of shape_cast (not) folding.
240+ // CHECK-LABEL: @negative_shape_cast_transpose
241+ // CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
242+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
243+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
244+ // CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
245+ func.func @negative_shape_cast_transpose (%arg : vector <6 xi8 >) -> vector <2 x3 xi8 > {
246+ %0 = vector.shape_cast %arg : vector <6 xi8 > to vector <3 x2 xi8 >
247+ %1 = vector.transpose %0 , [1 , 0 ] : vector <3 x2 xi8 > to vector <2 x3 xi8 >
248+ return %1 : vector <2 x3 xi8 >
249+ }
0 commit comments