@@ -753,7 +753,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
753753 return %1 : tensor <?x?x4 x5 xf32 >
754754}
755755
756- // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
757756// CHECK: func @linalg_add_reshape_consumer_fusion
758757// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
759758// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -774,18 +773,13 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
774773// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
775774// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
776775// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
777- // CHECK: %[[T4:.+]] = linalg.generic
778- // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
779- // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
776+ // CHECK: %[[T4:.+]] = linalg.add
780777// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
781778// CHECK-SAME: outs(%[[T3]] : tensor<?x?x4x5xf32>)
782779// CHECK: return %[[T4]] : tensor<?x?x4x5xf32>
783780
784781// -----
785782
786- #map0 = affine_map <(d0 , d1 , d2 ) -> (d2 , d0 )>
787- #map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
788- #map2 = affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>
789783func.func @linalg_add_reshape_producer_fusion (%arg0 : tensor <?x7 x?x8 xf32 >,
790784 %arg1 : tensor <?x?xf32 >,
791785 %arg2 : tensor <?x?xf32 >) ->
@@ -798,7 +792,6 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
798792 return %1 : tensor <?x?xf32 >
799793}
800794
801- // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
802795// CHECK: func @linalg_add_reshape_producer_fusion
803796// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
804797// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -817,16 +810,50 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
817810// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
818811// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
819812// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
820- // CHECK: %[[T3:.+]] = linalg.generic
821- // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
822- // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
813+ // CHECK: %[[T3:.+]] = linalg.add
823814// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
824815// CHECK-SAME: outs(%[[T2]] : tensor<?x7x?x8xf32>)
825816// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
826817// CHECK-SAME: [0, 1], [2, 3]
827818// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
828819// CHECK: return %[[T4]]
829820
821+ // -----
822+
823+ func.func @linalg_transpose_reshape_producer_fusion (%arg0 : tensor <?x7 x?x8 xf32 >,
824+ %arg1 : tensor <?x?xf32 >) ->
825+ tensor <?x?xf32 >
826+ {
827+ %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 , 3 ]] :
828+ tensor <?x7 x?x8 xf32 > into tensor <?x?xf32 >
829+ %1 = linalg.transpose ins (%0 : tensor <?x?xf32 >)
830+ outs (%arg1 : tensor <?x?xf32 >) permutation = [1 , 0 ]
831+ return %1 : tensor <?x?xf32 >
832+ }
833+
834+ // CHECK: func @linalg_transpose_reshape_producer_fusion
835+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
836+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
837+ // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
838+ // CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
839+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
840+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
841+ // CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
842+ // CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
843+ // CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
844+ // CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
845+ // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
846+ // CHECK: %[[T2:.+]] = linalg.transpose
847+ // CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
848+ // CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
849+ // CHECK-SAME: permutation = [2, 3, 0, 1]
850+ // CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
851+ // CHECK-SAME: [0, 1], [2, 3]
852+ // CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
853+ // CHECK: return %[[T3]]
854+
855+
856+
830857// -----
831858
832859func.func @fuse_by_expanding_pad (%arg0 : tensor <2 x3 x4 x5 x6 x7 x8 x9 xi32 >) -> tensor <8 x12 x17 x336 x14 xi32 > {
0 commit comments