Skip to content

Commit eb3aed5

Browse files
committed
Extend concat -> slice canonicalization to remove concat inputs if possible
1 parent 358df15 commit eb3aed5

File tree

2 files changed

+126
-22
lines changed

2 files changed

+126
-22
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -653,35 +653,47 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
653653

654654
llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
655655
llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
656-
657-
// Validate slice on the concatenated axis. Slicing along this
658-
// axis should span only one of the inputs to the concatenate
659-
// operation.
660-
std::optional<Value> replaceWithSlice;
656+
llvm::SmallVector<Value> requiredConcatInputs;
657+
int64_t processedOriginalConcatInputSize = 0;
658+
int64_t droppedConcatInputSize = 0;
661659
for (auto input : inputs) {
662-
auto inputType = dyn_cast<RankedTensorType>(input.getType());
660+
const auto inputType = dyn_cast<RankedTensorType>(input.getType());
663661
if (!inputType || !inputType.hasStaticShape())
664662
return rewriter.notifyMatchFailure(
665663
sliceOp, "concat input must be a static ranked tensor");
666-
667-
if (sliceStart[axis] >= 0 &&
668-
(sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
669-
replaceWithSlice = rewriter
670-
.create<tosa::SliceOp>(
671-
sliceOp.getLoc(), sliceOp.getType(), input,
672-
rewriter.getDenseI64ArrayAttr(sliceStart),
673-
rewriter.getDenseI64ArrayAttr(sliceSize))
674-
.getResult();
675-
break;
664+
if (processedOriginalConcatInputSize <
665+
(sliceStart[axis] + sliceSize[axis]) &&
666+
(processedOriginalConcatInputSize + inputType.getDimSize(axis)) >
667+
sliceStart[axis]) {
668+
if (requiredConcatInputs.empty()) {
669+
droppedConcatInputSize = processedOriginalConcatInputSize;
670+
}
671+
requiredConcatInputs.push_back(input);
676672
}
677-
sliceStart[axis] -= inputType.getDimSize(axis);
673+
processedOriginalConcatInputSize += inputType.getDimSize(axis);
678674
}
679-
680-
if (!replaceWithSlice)
675+
if (requiredConcatInputs.size() == concatOp->getNumOperands()) {
681676
return rewriter.notifyMatchFailure(
682-
sliceOp, "corresponding concat input not found for slice");
683-
684-
rewriter.replaceOp(sliceOp, replaceWithSlice.value());
677+
sliceOp, "Could not reduce number of inputs to preceding concat");
678+
}
679+
if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) {
680+
return rewriter.notifyMatchFailure(
681+
sliceOp,
682+
"Preceding concat must have a single use"); // Do not introduce new
683+
// concats
684+
}
685+
if (requiredConcatInputs.empty()) {
686+
return rewriter.notifyMatchFailure(
687+
sliceOp, "degenerate slice with zero sized dim in output");
688+
}
689+
sliceStart[axis] -= droppedConcatInputSize;
690+
auto newConcat = rewriter.create<tosa::ConcatOp>(concatOp->getLoc(),
691+
requiredConcatInputs, axis);
692+
auto newSlice = rewriter.create<tosa::SliceOp>(
693+
sliceOp->getLoc(), sliceOp.getType(), newConcat,
694+
rewriter.getDenseI64ArrayAttr(sliceStart),
695+
rewriter.getDenseI64ArrayAttr(sliceSize));
696+
rewriter.replaceOp(sliceOp, newSlice);
685697
return success();
686698
}
687699
};

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,98 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
829829

830830
// -----
831831

832+
// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_start_overlap
833+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
834+
// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32>
835+
// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32>
836+
// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32>
837+
func.func @canonicalize_concat_slice_partial_concat_start_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
838+
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
839+
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
840+
return %1 : tensor<1x12x12x2xf32>
841+
}
842+
843+
// -----
844+
845+
// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_end_overlap
846+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
847+
// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32>
848+
// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32>
849+
// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32>
850+
func.func @canonicalize_concat_slice_partial_concat_end_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
851+
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
852+
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 3>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
853+
return %1 : tensor<1x12x12x2xf32>
854+
}
855+
856+
// -----
857+
858+
// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_all_overlap
859+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> {
860+
// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
861+
// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 4>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32>
862+
// CHECK: return [[VAR_1_]] : tensor<1x12x12x4xf32>
863+
func.func @canonicalize_concat_slice_partial_concat_all_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> {
864+
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
865+
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 4>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32>
866+
return %1 : tensor<1x12x12x4xf32>
867+
}
868+
869+
// -----
870+
871+
// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_multi_use
872+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) {
873+
// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
874+
// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
875+
// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>
876+
func.func @canonicalize_concat_slice_partial_concat_multi_use(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) {
877+
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
878+
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
879+
return %0, %1 : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>
880+
}
881+
882+
// -----
883+
884+
// CHECK-LABEL: func.func @canonicalize_concat_slice_zero_dim
885+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> {
886+
// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
887+
// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 0>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32>
888+
// CHECK: return [[VAR_1_]] : tensor<1x12x12x0xf32>
889+
// CHECK: }
890+
func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> {
891+
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
892+
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 0>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32>
893+
return %1 : tensor<1x12x12x0xf32>
894+
}
895+
896+
// -----
897+
898+
// CHECK-LABEL: func.func @canonicalize_tile_slice
899+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> {
900+
// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 10, 2, 2, 3>} : (tensor<1x12x12x10x10xf32>) -> tensor<1x120x24x20x30xf32>
901+
// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 120, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 8>} : (tensor<1x120x24x20x30xf32>) -> tensor<1x120x12x10x16xf32>
902+
// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16xf32>
903+
func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> {
904+
%0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
905+
%1 = tosa.slice %0 {size = array<i64: 1, 120, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 18>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x120x12x10x16xf32>
906+
return %1 : tensor<1x120x12x10x16xf32>
907+
}
908+
909+
// -----
910+
911+
// CHECK-LABEL: func.func @canonicalize_tile_slice_multi_output
912+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) {
913+
// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
914+
// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 18>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32>
915+
// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>
916+
func.func @canonicalize_tile_slice_multi_output(%arg0 : tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) {
917+
%0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
918+
%1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 10, 16>, start = array<i64: 0, 0, 1, 1, 18>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32>
919+
return %0, %1 : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>
920+
}
921+
922+
// -----
923+
832924
// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal
833925
func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
834926
// CHECK: %[[RSQRT:.*]] = tosa.rsqrt %arg{{.*}} : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>

0 commit comments

Comments (0)