Skip to content

Commit 143e062

Browse files
committed
Add half pixel transformation to resize
1 parent 64f4dd5 commit 143e062

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

tests/test_backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
6565
quantize_and_dequantize = tf.quantization.quantize_and_dequantize
6666
resize_bilinear = tf.compat.v1.image.resize_bilinear
67+
resize_bilinear_v2 = tf.compat.v2.image.resize
6768
is_nan = tf.math.is_nan
6869
is_inf = tf.math.is_inf
6970
floormod = tf.math.floormod
@@ -81,6 +82,7 @@
8182
quantize_and_dequantize = tf.compat.v1.quantization.quantize_and_dequantize
8283
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
8384
resize_bilinear = tf.compat.v1.image.resize_bilinear
85+
resize_bilinear_v2 = tf.compat.v2.image.resize
8486
is_nan = tf.math.is_nan
8587
is_inf = tf.math.is_inf
8688
floormod = tf.floormod
@@ -1993,6 +1995,16 @@ def func(x, x_new_size_):
19931995
return tf.identity(x_, name=_TFOUTPUT)
19941996
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
19951997

1998+
@check_opset_min_version(11, "resize_bilinear_v2")
1999+
def test_resize_bilinear_v2_with_non_const(self):
2000+
x_shape = [3, 10, 8, 5]
2001+
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
2002+
x_new_size = np.array([20, 16]).astype(np.int32)
2003+
def func(x, x_new_size_):
2004+
x_ = resize_bilinear_v2(x, x_new_size_)
2005+
return tf.identity(x_, name=_TFOUTPUT)
2006+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
2007+
19962008
@check_opset_min_version(10, "resize scale can less than 1")
19972009
def test_resize_nearest_neighbor2(self):
19982010
x_shape = [1, 300, 20, 2]

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,9 +736,12 @@ def version_11(cls, ctx, node, **kwargs):
736736
const_empty_float.output[0],
737737
concat_shape.output[0]
738738
]
739+
transformation_mode = "asymmetric"
740+
if "half_pixel_centers" in node.attr and node.attr["half_pixel_centers"].i:
741+
transformation_mode = "half_pixel"
739742
resize = ctx.make_node("Resize", resize_inputs,
740743
attr={"mode": mode, "nearest_mode": "floor",
741-
"coordinate_transformation_mode": "asymmetric"})
744+
"coordinate_transformation_mode": transformation_mode})
742745
shapes = node.output_shapes
743746
dtypes = node.output_dtypes
744747
ctx.remove_node(node.name)

0 commit comments

Comments
 (0)