Commit 55654e9

Canonicalize 'self-concats' to tile
1 parent 1656bbb commit 55654e9
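
In short, the new pattern rewrites a tosa.concat whose operands are all the same SSA value into a single tosa.tile along the concat axis. A minimal before/after sketch in TOSA IR (shapes borrowed from the updated tests below, not part of the commit itself):

    // Before: a "self-concat" along axis 1.
    %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>

    // After canonicalization: a single tile with multiple 2 on the concat axis.
    %0 = tosa.tile %arg0 {multiples = array<i64: 1, 2, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>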

2 files changed: 56 additions & 15 deletions

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 42 additions & 0 deletions
@@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   }
 };
 
+struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::all_equal(concatOp->getUsers())) {
+      const auto concatUser = llvm::dyn_cast<tosa::ConcatOp>(
+          concatOp->getUses().begin()->getOwner());
+      if (concatUser) {
+        // Try folding the concat into its consumer before rewriting it to a
+        // tile.
+        SmallVector<Value> replacementValues;
+        auto foldResult = rewriter.tryFold(concatUser, replacementValues);
+        if (foldResult.succeeded()) {
+          if (!replacementValues.empty()) {
+            rewriter.replaceOp(concatUser, replacementValues);
+          }
+          return success();
+        }
+      }
+    }
+
+    if (!llvm::all_equal(concatOp->getOperands())) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "Requires all operands to be the same");
+    }
+    const auto concatType = dyn_cast<ShapedType>(concatOp.getType());
+    if (!concatType || !concatType.hasRank()) {
+      return rewriter.notifyMatchFailure(concatOp,
+                                         "Requires concat to be ranked");
+    }
+    SmallVector<int64_t> multiplies(concatType.getRank(), 1);
+    multiplies[concatOp.getAxis()] = concatOp->getNumOperands();
+    auto tileOp = rewriter.createOrFold<tosa::TileOp>(
+        concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0),
+        multiplies);
+    rewriter.replaceOp(concatOp, {tileOp});
+    return success();
+  }
+};
+
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<ConcatOptimization>(context);
+  results.add<SelfConcatToTile>(context);
 }
 
 struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
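
Two notes on the pattern above. First, when the concat's only user is itself a tosa.concat, the pattern calls rewriter.tryFold on that consumer before rewriting, so a concat chain that can fold away entirely is not turned into a tile prematurely. Second, the tile's multiples vector is all ones except at the concat axis, where it is set to the operand count: a 4-D self-concat of two operands on axis 1 yields multiples [1, 2, 1, 1]. A hypothetical sketch of the nested case (an illustration, not one of the committed tests, and assuming the consumer concat folds flat):

    // %1 is %0's only user and concatenates along the same axis ...
    %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
    %1 = tosa.concat %0, %0 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>

    // ... so after folding, the whole chain can become one tile:
    %1 = tosa.tile %arg0 {multiples = array<i64: 1, 4, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>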

mlir/test/Dialect/Tosa/fold_concats.mlir

Lines changed: 14 additions & 15 deletions
@@ -5,10 +5,10 @@ func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
   return %0 : tensor<1x2x7x7xf32>
 }
 
-// CHECK-LABEL: func.func @single_concat(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
+// CHECK-LABEL: func.func @single_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 2, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: return [[VAR_0_]] : tensor<1x2x7x7xf32>
 // CHECK: }
 
 // -----
@@ -19,11 +19,11 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32>
   return %1 : tensor<2x2x7x7xf32>
 }
 
-// CHECK-LABEL: func.func @concat_different_axis(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
-// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
+// CHECK-LABEL: func.func @concat_different_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 2, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array<i64: 2, 1, 1, 1>} : (tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK: return [[VAR_1_]] : tensor<2x2x7x7xf32>
 // CHECK: }
 
 // -----
@@ -84,10 +84,9 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>
   return %2 : tensor<1x4x8x8xf32>
 }
 
-// CHECK-LABEL: func.func @partially_foldable(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
-// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
+// CHECK-LABEL: func.func @partially_foldable
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_1_]] {multiples = array<i64: 1, 1, 2, 1>} : (tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_0_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x4x8x8xf32>
 // CHECK: }
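
The RUN line of fold_concats.mlir is outside the diff context; given the // ----- separators, the checks are presumably driven by something of the form (an assumption, not shown in this commit):

    // RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s

The partially_foldable case shows the pattern's limit: only the inner concat, whose operands are identical, becomes a tile, while the outer concat mixes distinct values and stays a concat.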
