
Commit 0aad4d9

Add canonicalization pattern for tile -> slice to minimize the tile multipliers
1 parent eb3aed5 commit 0aad4d9

File tree: 2 files changed, +88 −9 lines

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 68 additions & 2 deletions
@@ -687,8 +687,8 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
           sliceOp, "degenerate slice with zero sized dim in output");
     }
     sliceStart[axis] -= droppedConcatInputSize;
-    auto newConcat = rewriter.create<tosa::ConcatOp>(concatOp->getLoc(),
-                                                     requiredConcatInputs, axis);
+    auto newConcat = rewriter.create<tosa::ConcatOp>(
+        concatOp->getLoc(), requiredConcatInputs, axis);
     auto newSlice = rewriter.create<tosa::SliceOp>(
         sliceOp->getLoc(), sliceOp.getType(), newConcat,
         rewriter.getDenseI64ArrayAttr(sliceStart),
@@ -698,9 +698,75 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+/// This pattern adjusts the multipliers of a tile followed by a slice to only
+/// tile as much data as is required by the slice.
+struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+    auto tileOp = sliceInput.getDefiningOp<tosa::TileOp>();
+    if (!tileOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be tile operation");
+    if (!tileOp->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "preceding tile must have a single use"); // Do not insert
+                                                             // additional tiles
+
+    const auto tileOpInputType =
+        dyn_cast<RankedTensorType>(tileOp->getOperand(0).getType());
+    if (!tileOpInputType || !tileOpInputType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "input to preceding tile op must be a static ranked tensor");
+    llvm::SmallVector<int64_t> requiredMultipliers;
+    llvm::SmallVector<int64_t> newTileStarts;
+    requiredMultipliers.reserve(tileOpInputType.getRank());
+    newTileStarts.reserve(tileOpInputType.getRank());
+    for (auto [axis, sliceStart, sliceSize] :
+         llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) {
+      if (sliceSize <= 0) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "degenerate slice with zero sized dim");
+      }
+      const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis);
+      const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize;
+      const int64_t sliceSizeInFirstTile =
+          std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize);
+      assert(sliceSizeInFirstTile > 0);
+      const int64_t requiredMultiplierWithoutFirstTile =
+          llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize);
+      const int64_t requiredMultiplier =
+          requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0);
+      assert(requiredMultiplier <= tileOp.getMultiples()[axis]);
+      requiredMultipliers.push_back(requiredMultiplier);
+      newTileStarts.push_back(sliceOffsetInNewFirstTile);
+    }
+    if (requiredMultipliers == tileOp.getMultiples())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "could not reduce multipliers in preceding tile");
+
+    llvm::SmallVector<int64_t> newTileShape(tileOpInputType.getShape());
+    for (auto [newShape, multiplier] :
+         llvm::zip_equal(newTileShape, requiredMultipliers)) {
+      newShape *= multiplier;
+    }
+    auto newTile = rewriter.create<tosa::TileOp>(
+        tileOp->getLoc(), tileOpInputType.clone(newTileShape),
+        tileOp->getOperand(0), requiredMultipliers);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newTile,
+        rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());
+    rewriter.replaceOp(sliceOp, newSlice);
+    return success();
+  }
+};
+
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<ConcatSliceOptimization>(context);
+  results.add<TileSliceOptimization>(context);
 }
 
 struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
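
To make the per-axis loop above concrete: the slice start is folded into the first kept copy of the tile input, and only enough copies to cover the slice are retained. Below is a minimal standalone C++ sketch of that arithmetic; the AxisTiling struct and computeAxisTiling helper are illustrative names and not part of the patch.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical helper (illustrative, not from the patch) mirroring the
// per-axis logic of TileSliceOptimization.
struct AxisTiling {
  int64_t requiredMultiplier; // shrunken tile multiplier for this axis
  int64_t newSliceStart;      // slice start within the shrunken tile output
};

AxisTiling computeAxisTiling(int64_t tileInputDimSize, int64_t sliceStart,
                             int64_t sliceSize) {
  assert(tileInputDimSize > 0 && sliceSize > 0);
  // The slice begins somewhere inside one copy of the tile input; dropping
  // all earlier copies turns that offset into the new slice start.
  const int64_t offsetInFirstTile = sliceStart % tileInputDimSize;
  // Portion of the slice served by that first kept copy.
  const int64_t sizeInFirstTile =
      std::min(tileInputDimSize - offsetInFirstTile, sliceSize);
  // The remainder needs ceil(rest / tileInputDimSize) further copies,
  // plus the first kept copy itself.
  const int64_t rest = sliceSize - sizeInFirstTile;
  const int64_t multiplier =
      (rest + tileInputDimSize - 1) / tileInputDimSize + 1;
  return {multiplier, offsetInFirstTile};
}

int main() {
  // Axis 4 of @canonicalize_tile_slice in the test below: input dim 10,
  // slice start 18, slice size 16 -> the CHECK lines expect multiplier 3
  // and new start 8.
  const AxisTiling t = computeAxisTiling(10, 18, 16);
  assert(t.requiredMultiplier == 3 && t.newSliceStart == 8);
  std::printf("multiplier=%lld start=%lld\n",
              static_cast<long long>(t.requiredMultiplier),
              static_cast<long long>(t.newSliceStart));
  return 0;
}

Because the slice never reads the dropped copies, shrinking the multipliers reduces the size of the intermediate tile result without changing the sliced values.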

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 20 additions & 7 deletions
@@ -896,14 +896,27 @@ func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %ar
 // -----
 
 // CHECK-LABEL: func.func @canonicalize_tile_slice
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> {
-// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 10, 2, 2, 3>} : (tensor<1x12x12x10x10xf32>) -> tensor<1x120x24x20x30xf32>
-// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 120, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 8>} : (tensor<1x120x24x20x30xf32>) -> tensor<1x120x12x10x16xf32>
-// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16xf32>
-func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> {
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 10, 2, 2, 3, 1>} : (tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x24x20x30x10xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 120, 12, 10, 16, 5>, start = array<i64: 0, 0, 1, 1, 8, 1>} : (tensor<1x120x24x20x30x10xf32>) -> tensor<1x120x12x10x16x5xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16x5xf32>
+func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> {
+  %0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10x10xf32>) -> tensor<10x120x120x100x100x100xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 120, 12, 10, 16, 5>, start = array<i64: 0, 0, 1, 1, 18, 1>} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x120x12x10x16x5xf32>
+  return %1 : tensor<1x120x12x10x16x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_tile_slice_zero_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 0, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 18>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x0x12x10x16xf32>
+func.func @canonicalize_tile_slice_zero_dim(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> {
   %0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
-  %1 = tosa.slice %0 {size = array<i64: 1, 120, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 18>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x120x12x10x16xf32>
-  return %1 : tensor<1x120x12x10x16xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 0, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 18>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32>
+  return %1 : tensor<1x0x12x10x16xf32>
 }
 
 // -----
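
As a cross-check of the first test, applying the same per-axis arithmetic to all six axes reproduces the multiples and starts expected by the CHECK lines. The values below are copied from @canonicalize_tile_slice; the loop body restates the hypothetical helper logic from the earlier sketch and is not code from the patch.

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Values copied from @canonicalize_tile_slice above.
  const int64_t dim[6]   = {1, 12, 12, 10, 10, 10}; // tile input shape
  const int64_t start[6] = {0, 0, 1, 1, 18, 1};     // original slice start
  const int64_t size[6]  = {1, 120, 12, 10, 16, 5}; // slice size
  const int64_t wantMult[6]  = {1, 10, 2, 2, 3, 1}; // CHECK: multiples
  const int64_t wantStart[6] = {0, 0, 1, 1, 8, 1};  // CHECK: start
  for (int axis = 0; axis < 6; ++axis) {
    const int64_t offset = start[axis] % dim[axis];
    const int64_t inFirst = std::min(dim[axis] - offset, size[axis]);
    const int64_t mult =
        (size[axis] - inFirst + dim[axis] - 1) / dim[axis] + 1;
    assert(mult == wantMult[axis] && offset == wantStart[axis]);
  }
  return 0;
}

The second test covers the degenerate case: with a zero-sized slice dimension, the pattern reports a match failure, so the tile's multiples stay unchanged.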
