Skip to content

Commit 43c01e5

Browse files
authored
Merge pull request #932 from peri044/resize
Add half pixel transformation to resize bilinear op
2 parents 64f4dd5 + 9928427 commit 43c01e5

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
6565
quantize_and_dequantize = tf.quantization.quantize_and_dequantize
6666
resize_bilinear = tf.compat.v1.image.resize_bilinear
67+
resize_bilinear_v2 = tf.compat.v2.image.resize
6768
is_nan = tf.math.is_nan
6869
is_inf = tf.math.is_inf
6970
floormod = tf.math.floormod
@@ -81,6 +82,7 @@
8182
quantize_and_dequantize = tf.compat.v1.quantization.quantize_and_dequantize
8283
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
8384
resize_bilinear = tf.compat.v1.image.resize_bilinear
85+
resize_bilinear_v2 = tf.compat.v2.image.resize
8486
is_nan = tf.math.is_nan
8587
is_inf = tf.math.is_inf
8688
floormod = tf.floormod
@@ -1993,6 +1995,17 @@ def func(x, x_new_size_):
19931995
return tf.identity(x_, name=_TFOUTPUT)
19941996
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
19951997

1998+
@check_tf_min_version("1.14")
1999+
@check_opset_min_version(11, "resize_bilinear_v2")
2000+
def test_resize_bilinear_v2_with_non_const(self):
2001+
x_shape = [3, 10, 8, 5]
2002+
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
2003+
x_new_size = np.array([20, 16]).astype(np.int32)
2004+
def func(x, x_new_size_):
2005+
x_ = resize_bilinear_v2(x, x_new_size_)
2006+
return tf.identity(x_, name=_TFOUTPUT)
2007+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
2008+
19962009
@check_opset_min_version(10, "resize scale can less than 1")
19972010
def test_resize_nearest_neighbor2(self):
19982011
x_shape = [1, 300, 20, 2]

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,9 +736,12 @@ def version_11(cls, ctx, node, **kwargs):
736736
const_empty_float.output[0],
737737
concat_shape.output[0]
738738
]
739+
transformation_mode = "asymmetric"
740+
if "half_pixel_centers" in node.attr and node.attr["half_pixel_centers"].i:
741+
transformation_mode = "half_pixel"
739742
resize = ctx.make_node("Resize", resize_inputs,
740743
attr={"mode": mode, "nearest_mode": "floor",
741-
"coordinate_transformation_mode": "asymmetric"})
744+
"coordinate_transformation_mode": transformation_mode})
742745
shapes = node.output_shapes
743746
dtypes = node.output_dtypes
744747
ctx.remove_node(node.name)

0 commit comments

Comments
 (0)