Skip to content

Commit 620b8b9

Browse files
Merge pull request #1146 from onnx/tom/FixResizeNN
Fixed half_pixel_centers for resize_nearest_neighbor
2 parents 0e2ea55 + 0338e1e commit 620b8b9

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,19 @@ def func(x):
21172117
return tf.identity(x_, name=_TFOUTPUT)
21182118
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
21192119

2120+
@skip_caffe2_backend()
2121+
@check_tf_min_version("1.14")
2122+
@check_opset_min_version(11, "coordinate_transformation_mode attr")
2123+
def test_resize_bilinear_half_pixel_centers(self):
2124+
x_shape = [1, 15, 20, 2]
2125+
x_new_size = [30, 40]
2126+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
2127+
def func(x):
2128+
x_new_size_ = tf.constant(x_new_size)
2129+
x_ = resize_bilinear(x, x_new_size_, half_pixel_centers=True)
2130+
return tf.identity(x_, name=_TFOUTPUT)
2131+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2132+
21202133
@check_opset_min_version(9, "resize_bilinear")
21212134
def test_resize_bilinear_with_non_const(self):
21222135
x_shape = [3, 10, 8, 5]
@@ -2160,6 +2173,18 @@ def func(x):
21602173
return tf.identity(x_, name=_TFOUTPUT)
21612174
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
21622175

2176+
@check_tf_min_version("1.14")
2177+
@check_opset_min_version(11, "coordinate_transformation_mode attr")
2178+
def test_resize_nearest_neighbor_half_pixel_centers(self):
2179+
x_shape = [1, 10, 20, 2]
2180+
x_new_size = [20, 40]
2181+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
2182+
def func(x):
2183+
x_new_size_ = tf.constant(x_new_size)
2184+
x_ = resize_nearest_neighbor(x, x_new_size_, half_pixel_centers=True)
2185+
return tf.identity(x_, name=_TFOUTPUT)
2186+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2187+
21632188
@check_opset_min_version(9, "fill")
21642189
def test_fill_float32(self):
21652190
x_shape = [1, 15, 20, 2]

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,10 @@ def version_11(cls, ctx, node, **kwargs):
903903
if "align_corners" in node.attr and node.attr["align_corners"].i:
904904
transformation_mode = "align_corners"
905905
if "half_pixel_centers" in node.attr and node.attr["half_pixel_centers"].i:
906-
transformation_mode = "half_pixel"
906+
if node.type == "ResizeNearestNeighbor":
907+
transformation_mode = "tf_half_pixel_for_nn"
908+
else:
909+
transformation_mode = "half_pixel"
907910
resize = ctx.make_node("Resize", resize_inputs,
908911
attr={"mode": mode, "nearest_mode": "floor",
909912
"coordinate_transformation_mode": transformation_mode})

0 commit comments

Comments
 (0)