@@ -1251,6 +1251,29 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
12511251
12521252// -----
12531253
// Composing expand_shape(collapse_shape) when the source has a dynamic
// leading dim: the dim of the collapsed value is rewritten against the
// source via affine.apply (s0 * 64), and the expand keeps consuming the
// collapsed tensor with a recomputed dynamic output_shape operand.
func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
  %c384 = arith.constant 384 : index
  %div = arith.divui %dim, %c384 : index
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
  return %expanded : tensor<?x384xf32>
}
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
// CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
// CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
// CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
// Use the captured %[[DIVUI]] variable rather than a hardcoded SSA name
// (%1), which would break if the printer renumbers values.
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
// CHECK: return %[[RESULT]]
1274+
1275+ // -----
1276+
12541277func.func @compose_expand_of_collapse (%arg0 : tensor <2 x3 x4 x5 x6 x7 x8 xf32 >)
12551278 -> tensor <24 x5 x42 x8 xf32 > {
12561279 %0 = tensor.collapse_shape %arg0 [[0 , 1 , 2 , 3 , 4 , 5 , 6 ]]
0 commit comments