diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index a32e4ccbed594..3135fbd49bfbe 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2533,16 +2533,26 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents( } // Compute the output shape based on attributes: scale, offset, and border. - outputShape[1] = + const int64_t outputHeight = (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / scaleInt[1]) + 1; - outputShape[2] = + const int64_t outputWidth = (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / scaleInt[3]) + 1; + if (outputHeight < 0 || outputWidth < 0) { + return emitOptionalError( + location, + "calculated output height and width must be non-negative, " + "got height = ", + outputHeight, ", width = ", outputWidth); + } + + outputShape[1] = outputHeight; + outputShape[2] = outputWidth; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 591a3f0acf65d..18409d24fbc1b 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1115,6 +1115,18 @@ func.func @resize_fp_power_of_two_upscale_offsetted(%arg0: tensor<1x50x48x1xf32> // ----- +// CHECK-LABEL: @resize_negative_output_dim +func.func @resize_negative_output_dim(%arg0: tensor<1x3x1x1xi8>) { + %scale = tosa.const_shape { values = dense<[1, 3, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> + %offset = tosa.const_shape { values = dense<[6, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %border = tosa.const_shape { values = dense<[-15, 0]> : tensor<2xindex> } : () -> !tosa.shape<2> + // expected-error@+1 {{calculated output height and width must be non-negative, got height = -5, width = 0}} + %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x3x1x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xi8> + return +} + +// ----- + // CHECK-LABEL: @if_test_simple func.func @if_test_simple(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> () { %a = tosa.log %arg0 : (tensor) -> tensor