@@ -195,14 +195,37 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
195
195
// CHECK-SAME: : tensor<8x33x4xf32>
196
196
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
197
197
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
198
- // CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0 ]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
198
+ // CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
199
199
// CHECK: %[[T2:.+]] = linalg.generic
200
200
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
201
201
// CHECK-SAME: ["parallel", "parallel", "parallel"]
202
202
// CHECK-SAME: ins(%[[T0]], %[[CST]] :
203
203
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
204
204
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
205
205
206
+ // -----
207
+
208
+ func.func @reshape_as_consumer_transpose
209
+ (%a : tensor <4 x210 x6 xf32 >)
210
+ -> tensor <2 x3 x4 x5 x6 x7 xf32 > {
211
+ %b = tensor.empty () : tensor <6 x4 x210 xf32 >
212
+ %c = linalg.transpose
213
+ ins (%a : tensor <4 x210 x6 xf32 >)
214
+ outs (%b : tensor <6 x4 x210 xf32 >) permutation = [2 , 0 , 1 ]
215
+ %d = tensor.expand_shape %c [[0 , 1 ], [2 ], [3 , 4 , 5 ]] output_shape [2 , 3 , 4 , 5 , 6 , 7 ] : tensor <6 x4 x210 xf32 > into tensor <2 x3 x4 x5 x6 x7 xf32 >
216
+ return %d : tensor <2 x3 x4 x5 x6 x7 xf32 >
217
+ }
218
+ // CHECK: func @reshape_as_consumer_transpose
219
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
220
+ // CHECK-DAG: %[[INIT:.+]] = tensor.empty()
221
+ // CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
222
+ // CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
223
+ // CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
224
+ // CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
225
+ // CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
226
+ // CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32>
227
+
228
+
206
229
// -----
207
230
208
231
#map0 = affine_map <(d0 , d1 , d2 ) -> (d2 , d0 , d1 )>
@@ -884,37 +907,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
884
907
885
908
// -----
886
909
887
- func.func @linalg_transpose_reshape_producer_fusion (%arg0 : tensor <?x7 x?x8 xf32 >,
888
- %arg1 : tensor <?x?xf32 >) ->
889
- tensor <?x?xf32 >
890
- {
891
- %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 , 3 ]] :
892
- tensor <?x7 x?x8 xf32 > into tensor <?x?xf32 >
893
- %1 = linalg.transpose ins (%0 : tensor <?x?xf32 >)
894
- outs (%arg1 : tensor <?x?xf32 >) permutation = [1 , 0 ]
895
- return %1 : tensor <?x?xf32 >
910
+
911
+ func.func @reshape_as_producer_transpose
912
+ (%a : tensor <4 x5 x6 x7 x2 x3 xf32 >)
913
+ -> tensor <6 x4 x210 xf32 > {
914
+ %b = tensor.empty () : tensor <6 x4 x210 xf32 >
915
+ %c = tensor.collapse_shape %a [[0 ], [1 , 2 , 3 ], [4 , 5 ]] :
916
+ tensor <4 x5 x6 x7 x2 x3 xf32 > into tensor <4 x210 x6 xf32 >
917
+ %d = linalg.transpose
918
+ ins (%c : tensor <4 x210 x6 xf32 >)
919
+ outs (%b : tensor <6 x4 x210 xf32 >) permutation = [2 , 0 , 1 ]
920
+ return %d : tensor <6 x4 x210 xf32 >
896
921
}
897
922
898
- // CHECK: func @linalg_transpose_reshape_producer_fusion
899
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
900
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
901
- // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
902
- // CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
903
- // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
904
- // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
905
- // CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
906
- // CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
907
- // CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
908
- // CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
909
- // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
910
- // CHECK: %[[T2:.+]] = linalg.transpose
911
- // CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
912
- // CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
913
- // CHECK-SAME: permutation = [2, 3, 0, 1]
914
- // CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
915
- // CHECK-SAME: [0, 1], [2, 3]
916
- // CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
917
- // CHECK: return %[[T3]]
923
+ // CHECK: func @reshape_as_producer_transpose
924
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
925
+ // CHECK-DAG: %[[INIT:.+]] = tensor.empty()
926
+ // CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
927
+ // CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
928
+ // CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
929
+ // CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
930
+ // CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
931
+ // CHECK: return %[[T2]] : tensor<6x4x210xf32>
932
+
918
933
919
934
// -----
920
935
0 commit comments