@@ -165,6 +165,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
165165
166166// -----
167167
168+ // In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
169+ // 1 -> 0
170+ // 2 -> 4
171+ // Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
172+ // (same as the example above, but one of the dims is scalable)
173+ // CHECK-LABEL: @shape_cast_of_transpose_scalable
174+ // CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
175+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
176+ // CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
177+ // CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8>
178+ func.func @shape_cast_of_transpose_scalable (%arg : vector <1 x[4 ]x4 x1 x1 xi8 >) -> vector <[4 ]x4 xi8 > {
179+ %0 = vector.transpose %arg , [1 , 0 , 3 , 4 , 2 ]
180+ : vector <1 x[4 ]x4 x1 x1 xi8 > to vector <[4 ]x1 x1 x1 x4 xi8 >
181+ %1 = vector.shape_cast %0 : vector <[4 ]x1 x1 x1 x4 xi8 > to vector <[4 ]x4 xi8 >
182+ return %1 : vector <[4 ]x4 xi8 >
183+ }
184+
185+ // -----
186+
168187// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
169188// 1 -> 2
170189// 2 -> 1
@@ -184,36 +203,10 @@ func.func @negative_shape_cast_of_transpose(%arg : vector<1x4x4x1xi8>) -> vector
184203
185204// -----
186205
187- // Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
188- // scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
189- // CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
190- // CHECK: vector.transpose
191- // CHECK: vector.shape_cast
192- func.func @negative_shape_cast_of_transpose_scalable (%arg : vector <[4 ]x1 xi8 >) -> vector <[4 ]xi8 > {
193- %0 = vector.transpose %arg , [1 , 0 ] : vector <[4 ]x1 xi8 > to vector <1 x[4 ]xi8 >
194- %1 = vector.shape_cast %0 : vector <1 x[4 ]xi8 > to vector <[4 ]xi8 >
195- return %1 : vector <[4 ]xi8 >
196- }
197-
198- // -----
199-
200206/// +--------------------------------------------------------------------------
201207/// Tests of FoldTransposeShapeCast: transpose(shape_cast) -> shape_cast
202208/// +--------------------------------------------------------------------------
203209
204- // The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
205- // vectors.
206- // CHECK-LABEL: @transpose_of_shape_cast_scalable
207- // CHECK: vector.shape_cast
208- // CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
209- func.func @transpose_of_shape_cast_scalable (%arg : vector <[4 ]xi8 >) -> vector <[4 ]x1 xi8 > {
210- %0 = vector.shape_cast %arg : vector <[4 ]xi8 > to vector <1 x[4 ]xi8 >
211- %1 = vector.transpose %0 , [1 , 0 ] : vector <1 x[4 ]xi8 > to vector <[4 ]x1 xi8 >
212- return %1 : vector <[4 ]x1 xi8 >
213- }
214-
215- // -----
216-
217210// A transpose that is 'order preserving' can be treated like a shape_cast.
218211// CHECK-LABEL: @transpose_of_shape_cast
219212// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
@@ -229,11 +222,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
229222
230223// -----
231224
232- // Scalable dimensions should be treated as non-unit dimensions.
233225// CHECK-LABEL: @transpose_of_shape_cast_scalable
226+ // CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
227+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
228+ // CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
229+ // CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
230+ func.func @transpose_of_shape_cast_scalable (%arg : vector <[2 ]x3 x1 x1 xi8 >) -> vector <[6 ]x1 x1 xi8 > {
231+ %0 = vector.shape_cast %arg : vector <[2 ]x3 x1 x1 xi8 > to vector <[6 ]x1 x1 xi8 >
232+ %1 = vector.transpose %0 , [0 , 2 , 1 ]
233+ : vector <[6 ]x1 x1 xi8 > to vector <[6 ]x1 x1 xi8 >
234+ return %1 : vector <[6 ]x1 x1 xi8 >
235+ }
236+
237+ // -----
238+
239+ // Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
240+ // (hence no folding).
241+ // CHECK-LABEL: @negative_transpose_of_shape_cast_scalable_unit
234242// CHECK: vector.shape_cast
235243// CHECK: vector.transpose
236- func.func @transpose_of_shape_cast_scalable_unit (%arg : vector <[1 ]x4 x1 xi8 >) -> vector <4 x[1 ]xi8 > {
244+ func.func @negative_transpose_of_shape_cast_scalable_unit (%arg : vector <[1 ]x4 x1 xi8 >) -> vector <4 x[1 ]xi8 > {
237245 %0 = vector.shape_cast %arg : vector <[1 ]x4 x1 xi8 > to vector <[1 ]x4 xi8 >
238246 %1 = vector.transpose %0 , [1 , 0 ] : vector <[1 ]x4 xi8 > to vector <4 x[1 ]xi8 >
239247 return %1 : vector <4 x[1 ]xi8 >
0 commit comments