@@ -195,14 +195,37 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
195195// CHECK-SAME: : tensor<8x33x4xf32>
196196// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
197197// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
198- // CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0 ]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
198+ // CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT ]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
199199// CHECK: %[[T2:.+]] = linalg.generic
200200// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
201201// CHECK-SAME: ["parallel", "parallel", "parallel"]
202202// CHECK-SAME: ins(%[[T0]], %[[CST]] :
203203// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
204204// CHECK: return %[[T2]] : tensor<8x33x4xf32>
205205
206+ // -----
207+
208+ func.func @reshape_as_consumer_transpose
209+ (%a : tensor <4 x210 x6 xf32 >)
210+ -> tensor <2 x3 x4 x5 x6 x7 xf32 > {
211+ %b = tensor.empty () : tensor <6 x4 x210 xf32 >
212+ %c = linalg.transpose
213+ ins (%a : tensor <4 x210 x6 xf32 >)
214+ outs (%b : tensor <6 x4 x210 xf32 >) permutation = [2 , 0 , 1 ]
215+ %d = tensor.expand_shape %c [[0 , 1 ], [2 ], [3 , 4 , 5 ]] output_shape [2 , 3 , 4 , 5 , 6 , 7 ] : tensor <6 x4 x210 xf32 > into tensor <2 x3 x4 x5 x6 x7 xf32 >
216+ return %d : tensor <2 x3 x4 x5 x6 x7 xf32 >
217+ }
218+ // CHECK: func @reshape_as_consumer_transpose
219+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
220+ // CHECK-DAG: %[[INIT:.+]] = tensor.empty()
221+ // CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
222+ // CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32
223+ // CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
224+ // CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
225+ // CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
226+ // CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32>
227+
228+
206229// -----
207230
208231#map0 = affine_map <(d0 , d1 , d2 ) -> (d2 , d0 , d1 )>
@@ -884,37 +907,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
884907
885908// -----
886909
887- func.func @linalg_transpose_reshape_producer_fusion (%arg0 : tensor <?x7 x?x8 xf32 >,
888- %arg1 : tensor <?x?xf32 >) ->
889- tensor <?x?xf32 >
890- {
891- %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 , 3 ]] :
892- tensor <?x7 x?x8 xf32 > into tensor <?x?xf32 >
893- %1 = linalg.transpose ins (%0 : tensor <?x?xf32 >)
894- outs (%arg1 : tensor <?x?xf32 >) permutation = [1 , 0 ]
895- return %1 : tensor <?x?xf32 >
910+
911+ func.func @reshape_as_producer_transpose
912+ (%a : tensor <4 x5 x6 x7 x2 x3 xf32 >)
913+ -> tensor <6 x4 x210 xf32 > {
914+ %b = tensor.empty () : tensor <6 x4 x210 xf32 >
915+ %c = tensor.collapse_shape %a [[0 ], [1 , 2 , 3 ], [4 , 5 ]] :
916+ tensor <4 x5 x6 x7 x2 x3 xf32 > into tensor <4 x210 x6 xf32 >
917+ %d = linalg.transpose
918+ ins (%c : tensor <4 x210 x6 xf32 >)
919+ outs (%b : tensor <6 x4 x210 xf32 >) permutation = [2 , 0 , 1 ]
920+ return %d : tensor <6 x4 x210 xf32 >
896921}
897922
898- // CHECK: func @linalg_transpose_reshape_producer_fusion
899- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
900- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
901- // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
902- // CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
903- // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
904- // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
905- // CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
906- // CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
907- // CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
908- // CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
909- // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
910- // CHECK: %[[T2:.+]] = linalg.transpose
911- // CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
912- // CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
913- // CHECK-SAME: permutation = [2, 3, 0, 1]
914- // CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
915- // CHECK-SAME: [0, 1], [2, 3]
916- // CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
917- // CHECK: return %[[T3]]
923+ // CHECK: func @reshape_as_producer_transpose
924+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
925+ // CHECK-DAG: %[[INIT:.+]] = tensor.empty()
926+ // CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
927+ // CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
928+ // CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
929+ // CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
930+ // CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
931+ // CHECK: return %[[T2]] : tensor<6x4x210xf32>
932+
918933
919934// -----
920935
0 commit comments