Skip to content

Commit 51b87e7

Browse files
author
Samnour2
committed
Add: prevent onnx to tosa legalization when the size (lhs most parameter) is not constant
1 parent b654c07 commit 51b87e7

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

src/Conversion/ONNXToTOSA/Tensor/Resize.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct ScaleHelper {
3131
ScaleHelper(
3232
int64_t numerator, int64_t denominator, int64_t offset, int64_t border)
3333
: numerator(numerator), denominator(denominator), offset(offset),
34-
border(border){};
34+
border(border) {};
3535
int64_t numerator, denominator, offset, border;
3636
};
3737

@@ -203,6 +203,16 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern {
203203
resizeOp, "Only static sized tensors are supported.");
204204
}
205205

206+
Value sizesValue = resizeOp.getSizes();
207+
if (!isNoneValue(sizesValue)) {
208+
mlir::ElementsAttr sizesAttr =
209+
getElementAttributeFromONNXValue(sizesValue);
210+
if (!sizesAttr) {
211+
return rewriter.notifyMatchFailure(
212+
resizeOp, "Sizes must be a constant tensor for static inputs.");
213+
}
214+
}
215+
206216
auto elementType = inputType.getElementType();
207217
if (!(isa<FloatType>(elementType) || isTOSAInt(elementType))) {
208218
return rewriter.notifyMatchFailure(

test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,26 @@ func.func @test_resize_cubic_disallowed(%arg0: tensor<1x1x2x4xf32>) -> tensor<1x
252252
// CHECK-LABEL: func.func @test_resize_cubic_disallowed
253253
// CHECK-LABEL: onnx.Resize
254254
}
255+
256+
// -----
257+
258+
func.func @test_resize_size_constant_disallowed(%arg0: tensor<1x1x2x4xf32>, %arg1: tensor<4xi64>, %arg2: tensor<4xi64>) -> tensor<1x1x2x8xf32> {
259+
%0 = "onnx.NoValue"() {value} : () -> none
260+
%2 = "onnx.Add"(%arg1, %arg2) {} : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
261+
%3 = "onnx.Resize"(%arg0, %0, %0, %2) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "linear", nearest_mode = "floor"} : (tensor<1x1x2x4xf32>, none, none, tensor<4xi64>) -> tensor<1x1x2x8xf32>
262+
return %3 : tensor<1x1x2x8xf32>
263+
// CHECK-LABEL: func.func @test_resize_size_constant_disallowed
264+
// CHECK-LABEL: onnx.Resize
265+
}
266+
267+
// -----
268+
269+
func.func @test_resize_size_constant_allowed(%arg0: tensor<1x1x2x4xf32>, %arg1: tensor<4xi64>, %arg2: tensor<4xi64>) -> tensor<1x1x2x8xf32> {
270+
%0 = "onnx.NoValue"() {value} : () -> none
271+
%1 = "onnx.Constant"() {value = dense_resource<__elided__> : tensor<4xi64>} : () -> tensor<4xi64>
272+
%2 = "onnx.Resize"(%arg0, %0, %0, %1) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "linear", nearest_mode = "floor"} : (tensor<1x1x2x4xf32>, none, none, tensor<4xi64>) -> tensor<1x1x2x8xf32>
273+
return %2 : tensor<1x1x2x8xf32>
274+
// CHECK-LABEL: func.func @test_resize_size_constant_allowed
275+
// CHECK-NOT: onnx.Resize
276+
}
277+

0 commit comments

Comments
 (0)