Skip to content

Commit 358df15

Browse files
committed
Fold consecutive tosa.tiles
1 parent 55654e9 commit 358df15

File tree

3 files changed

+42
-3
lines changed

3 files changed

+42
-3
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
13621362
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
13631363
if (allOnes && getInput1().getType() == getType())
13641364
return getInput1();
1365+
1366+
if (auto inputTile = getInput1().getDefiningOp<TileOp>()) {
1367+
if (!inputTile->hasOneUse()) {
1368+
return {};
1369+
}
1370+
llvm::SmallVector<int64_t> newMultiplies{getMultiples()};
1371+
for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) {
1372+
newMultiplies[idx] *= multiplier;
1373+
}
1374+
setMultiples(newMultiplies);
1375+
setOperand(inputTile->getOperand(0));
1376+
getOperation()->setLoc(
1377+
FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()}));
1378+
return getResult();
1379+
}
13651380
return {};
13661381
}
13671382

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,31 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
691691

692692
// -----
693693

694+
// CHECK-LABEL: func.func @tile_fold_consecutive
695+
func.func @tile_fold_consecutive(%arg0: tensor<3x4xf32>) -> tensor<6x16xf32> {
696+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<6x16xf32> {
697+
// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 2, 4>} : (tensor<3x4xf32>) -> tensor<6x16xf32>
698+
// CHECK: return [[VAR_0_]] : tensor<6x16xf32>
699+
%0 = tosa.tile %arg0 { multiples = array<i64: 1, 2> }: (tensor<3x4xf32>) -> tensor<3x8xf32>
700+
%1 = tosa.tile %0 { multiples = array<i64: 2, 2> }: (tensor<3x8xf32>) -> tensor<6x16xf32>
701+
return %1 : tensor<6x16xf32>
702+
}
703+
704+
// -----
705+
706+
// CHECK-LABEL: func.func @tile_no_fold_consecutive_multi_use
707+
func.func @tile_no_fold_consecutive_multi_use(%arg0: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) {
708+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) {
709+
// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 2>} : (tensor<3x4xf32>) -> tensor<3x8xf32>
710+
// CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array<i64: 2, 2>} : (tensor<3x8xf32>) -> tensor<6x16xf32>
711+
// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<3x8xf32>, tensor<6x16xf32>
712+
%0 = tosa.tile %arg0 { multiples = array<i64: 1, 2> }: (tensor<3x4xf32>) -> tensor<3x8xf32>
713+
%1 = tosa.tile %0 { multiples = array<i64: 2, 2> }: (tensor<3x8xf32>) -> tensor<6x16xf32>
714+
return %0, %1 : tensor<3x8xf32>, tensor<6x16xf32>
715+
}
716+
717+
// -----
718+
694719
// CHECK-LABEL: @tile_nofold
695720
func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
696721
// CHECK: tosa.tile

mlir/test/Dialect/Tosa/fold_concats.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf
2121

2222
// CHECK-LABEL: func.func @concat_different_axis
2323
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
24-
// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 2, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
25-
// CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array<i64: 2, 1, 1, 1>} : (tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
26-
// CHECK: return [[VAR_1_]] : tensor<2x2x7x7xf32>
24+
// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 2, 2, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32>
25+
// CHECK: return [[VAR_0_]] : tensor<2x2x7x7xf32>
2726
// CHECK: }
2827

2928
// -----

0 commit comments

Comments
 (0)