Skip to content

Commit 1428a11

Browse files
committed
[TOSA] bug fix infer shape for slice
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check: - size = -1 - size is out of bound - start is out of bound Signed-off-by: Tai Ly <[email protected]> Change-Id: I8b59502a93cb332fe5c9a9f87970b83742538126
1 parent 8287831 commit 1428a11

File tree

2 files changed

+76
-2
lines changed

2 files changed

+76
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -842,8 +842,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
842842
MLIRContext *context, ::std::optional<Location> location,
843843
SliceOp::Adaptor adaptor,
844844
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
845-
inferredReturnShapes.push_back(
846-
ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
845+
auto start = adaptor.getStart();
846+
auto size = adaptor.getSize();
847+
848+
// if size[i] is -1, all remaining elements in dimension i are included
849+
// in the slice, similar to TF.
850+
ShapeAdaptor inputShape(adaptor.getInput().getType());
851+
// initialize outputShape to all unknown
852+
SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
853+
if (inputShape.hasRank()) {
854+
for (size_t i = 0; i < size.size(); i++) {
855+
if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
856+
(ShapedType::isDynamic(inputShape.getDimSize(i)) ||
857+
start[i] < inputShape.getDimSize(i))) {
858+
// size[i] is not 0 and not < -1, and start[i] is in valid range
859+
if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
860+
// input shape has unknown dim[i] - only valid if size[i] > 0
861+
if (size[i] > 0) {
862+
outputShape[i] = size[i];
863+
}
864+
} else {
865+
// input shape has known dim[i]
866+
if (size[i] == -1) {
867+
outputShape[i] = inputShape.getDimSize(i) - start[i];
868+
} else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
869+
// start[i] + size[i] is within bound of input shape's dim[i]
870+
outputShape[i] = size[i];
871+
}
872+
}
873+
}
874+
}
875+
} else {
876+
outputShape = convertToMlirShape(size);
877+
}
878+
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
847879
return success();
848880
}
849881

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
532532

533533
// -----
534534

535+
// CHECK-LABEL: @test_slice_size_minus_one
536+
func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
537+
// CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
538+
// this checks following
539+
// dim 0: size=-1, input dim=? => inferred output dim is ?
540+
// dim 1: size=-1 => inferred output dim is input_dim - start
541+
// dim 2: size=-1, start=-1 => inferred output dim is ?
542+
// dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
543+
%2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
544+
return
545+
}
546+
547+
// -----
548+
549+
// CHECK-LABEL: @test_slice_size_out_of_bound
550+
func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
551+
// CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
552+
// this checks following
553+
// dim 0: size=0 => inferred output dim is ?
554+
// dim 1: size=-2 => inferred output dim is ?
555+
// dim 3: start+size out of bound because size too big: inferred output dim is ?
556+
// dim 4: size=4, input dim=? => inferred output dim is 4
557+
%2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
558+
return
559+
}
560+
561+
// -----
562+
563+
// CHECK-LABEL: @test_slice_start_out_of_bound
564+
func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
565+
// CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
566+
// this checks following
567+
// dim 0: start=-1 => inferred output dim is ?
568+
// dim 1: start=8 => inferred output dim is ?
569+
// dim 2: start+size out of bound: inferred output dim is ?
570+
// dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
571+
%2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
572+
return
573+
}
574+
575+
// -----
576+
535577
// CHECK-LABEL: @test_slice_dynamic
536578
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
537579
// CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>

0 commit comments

Comments
 (0)