
Commit 20ae70c

Merge pull request #611 from Xilinx/rogarcia.optimize_concat_to_avoid_multiple_concatenations
feat: Do not inline an operand concat's inputs in the folder if that concat is also used by another operation
2 parents 42131ee + 519e096

2 files changed: +46, -1 lines changed


mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 8 additions & 0 deletions
@@ -1693,6 +1693,14 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
     if (getAxis() != producer.getAxis())
       continue;
 
+    // If this operand concat has multiple users that are different
+    // operations, the operand concat must be materialized anyway, so do not
+    // inline its operands here; that would concatenate the same data twice.
+    const bool allConcatUsersAreThisConcat = llvm::all_of(
+        producer->getUsers(), [&](Operation *user) { return *this == user; });
+    if (!allConcatUsersAreThisConcat)
+      continue;
+
     // Replace the original operand with all incoming operands
     foundFoldableConcat = true;
     concatOperands.pop_back();

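To illustrate what the guard prevents, here is a hand-written sketch (not part of the commit; it mirrors the new concat_multiple_users test below) in which the inner concat has a second, distinct consumer:

  func.func @sketch(%a: tensor<1x1x7x7xf32>, %b: tensor<1x1x7x7xf32>, %c: tensor<1x1x7x7xf32>) -> (tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>) {
    // %0 has two users that are different operations (%1 and %2),
    // so the folder now leaves %1 untouched.
    %0 = tosa.concat %a, %b {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
    %1 = tosa.concat %c, %0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x3x7x7xf32>
    %2 = tosa.add %0, %0 : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32>
    return %1, %2 : tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>
  }

Without the guard, the folder would rewrite %1 into tosa.concat %c, %a, %b, yet %0 would still have to be materialized for the tosa.add, so the concatenation of %a and %b would be performed twice at runtime.
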
mlir/test/Dialect/Tosa/fold_concats.mlir

Lines changed: 38 additions & 1 deletion
@@ -62,6 +62,43 @@ func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> {
 
 // -----
 
+func.func @concat_multiple_users(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> (tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>) {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %tmp, %0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x3x7x7xf32>
+  %2 = tosa.add %0, %0 : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32>
+  return %1, %2 : tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @concat_multiple_users
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_1_:%.+]]: tensor<1x1x7x7xf32>) -> (tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<1x1x7x7xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x3x7x7xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = tosa.add [[VAR_1_]], [[VAR_1_]] : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: return [[VAR_2_]], [[VAR_3_]] : tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>
+// CHECK: }
+
+// -----
+
+func.func @concat_diamond_shape(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>, %arg2: tensor<1x1x7x7xf32>, %arg3: tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32> {
+  %tmp = tensor.empty() : tensor<1x1x7x7xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+  %1 = tosa.concat %0, %arg2 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x3x7x7xf32>
+  %2 = tosa.concat %0, %arg3 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x3x7x7xf32>
+  %3 = tosa.concat %1, %2 {axis = 1 : i32} : (tensor<1x3x7x7xf32>, tensor<1x3x7x7xf32>) -> tensor<1x6x7x7xf32>
+  return %3 : tensor<1x6x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @concat_diamond_shape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_1_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_2_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_3_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_0_]], [[PARAM_1_]], [[PARAM_3_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32>
+// CHECK: return [[VAR_0_]] : tensor<1x6x7x7xf32>
+// CHECK: }
+
+// -----
+
 func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
   %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
   %1 = tosa.concat %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
@@ -91,4 +128,4 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x
 // CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_1_]], [[VAR_0_]] : (tensor<1x2x4x8xf32>, !tosa.shape<4>) -> tensor<1x2x8x8xf32>
 // CHECK: [[VAR_2_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
 // CHECK: return [[VAR_2_]] : tensor<1x4x8x8xf32>
-// CHECK: }
+// CHECK: }
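
A note on why the diamond test still collapses into a single concat: the guard compares users, not use counts. When the outer concat %3 folds, %1 and %2 each have %3 as their only user, so their operands are inlined, leaving roughly this intermediate IR (again a hand-written sketch, not part of the commit):

  %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
  %3 = tosa.concat %0, %arg2, %0, %arg3 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32>

Both uses of %0 now belong to the same operation, so allConcatUsersAreThisConcat holds and %0 is inlined as well, yielding the six-operand concat in the CHECK lines. Assuming the file's RUN line follows the usual pattern for these tests (// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s), each // ----- block is checked independently.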
