Skip to content

Commit 70529a5

Browse files
Support for bicubic resize (#1299)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent e3ea89c commit 70529a5

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

tests/test_backend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2437,6 +2437,17 @@ def func(x, x_new_size_):
24372437
return tf.identity(x_, name=_TFOUTPUT)
24382438
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
24392439

2440+
@check_tf_min_version("2.0", "Results are slightly different in tf1")
2441+
@check_opset_min_version(11, "resize bicubic")
2442+
def test_resize_bicubic(self):
2443+
x_shape = [1, 15, 20, 2]
2444+
new_size_val = np.array([30, 40], dtype=np.int32)
2445+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
2446+
def func(x, new_size):
2447+
y = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BICUBIC)
2448+
return tf.identity(y, name=_TFOUTPUT)
2449+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: new_size_val}, rtol=1e-6, atol=1e-5)
2450+
24402451
@check_opset_min_version(10, "resize scale can less than 1")
24412452
def test_resize_nearest_neighbor2(self):
24422453
x_shape = [1, 300, 20, 2]

tf2onnx/onnx_opset/nn.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -978,10 +978,12 @@ def version_13(cls, ctx, node, **kwargs):
978978
cls.any_version_after11(13, ctx, node, **kwargs)
979979

980980

981-
@tf_op(["ResizeBilinear", "ResizeNearestNeighbor"])
981+
@tf_op(["ResizeBilinear", "ResizeNearestNeighbor", "ResizeBicubic"])
982982
class Resize:
983983
@classmethod
984984
def version_7(cls, ctx, node, **kwargs):
985+
utils.make_sure(node.type != "ResizeBicubic", "Opset 11 is required for bicubic interpolation for node %s",
986+
node.name)
985987
mode = "linear" if node.type == "ResizeBilinear" else "nearest"
986988
node.type = "Upsample"
987989
shape = ctx.get_shape(node.input[0])
@@ -1009,7 +1011,16 @@ def version_10(cls, ctx, node, **kwargs):
10091011

10101012
@classmethod
10111013
def version_11(cls, ctx, node, **kwargs):
1012-
mode = "linear" if node.type == "ResizeBilinear" else "nearest"
1014+
cubic_coeff_a = None
1015+
exclude_outside = False
1016+
if node.type == "ResizeBilinear":
1017+
mode = "linear"
1018+
elif node.type == "ResizeBicubic":
1019+
mode = "cubic"
1020+
cubic_coeff_a = -0.5
1021+
exclude_outside = True
1022+
else:
1023+
mode = "nearest"
10131024
roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
10141025
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64))
10151026
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64))
@@ -1035,9 +1046,11 @@ def version_11(cls, ctx, node, **kwargs):
10351046
nearest_mode = "round_prefer_ceil"
10361047
else:
10371048
transformation_mode = "half_pixel"
1038-
resize = ctx.make_node("Resize", resize_inputs,
1039-
attr={"mode": mode, "nearest_mode": nearest_mode,
1040-
"coordinate_transformation_mode": transformation_mode})
1049+
attr = {"mode": mode, "nearest_mode": nearest_mode, "coordinate_transformation_mode": transformation_mode,
1050+
"exclude_outside": exclude_outside}
1051+
if cubic_coeff_a is not None:
1052+
attr["cubic_coeff_a"] = cubic_coeff_a
1053+
resize = ctx.make_node("Resize", resize_inputs, attr=attr)
10411054
shapes = node.output_shapes
10421055
dtypes = node.output_dtypes
10431056
ctx.remove_node(node.name)
@@ -1050,6 +1063,8 @@ def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
10501063
# float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
10511064
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
10521065
# wants the input to be NHWC - adjust target_shape to this.
1066+
utils.make_sure(node.type != "ResizeBicubic", "Opset 11 is required for bicubic interpolation for node %s",
1067+
node.name)
10531068
mode = "linear" if node.type == "ResizeBilinear" else "nearest"
10541069

10551070
# because onnxruntime only supports to scale the last two dims so transpose is inserted

0 commit comments

Comments
 (0)