Skip to content

Commit 0338e1e

Browse files
Fixed half_pixel_centers for resize_nearest_neighbor
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 57ccfb3 commit 0338e1e

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,6 +2101,19 @@ def func(x):
21012101
return tf.identity(x_, name=_TFOUTPUT)
21022102
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
21032103

2104+
@skip_caffe2_backend()
2105+
@check_tf_min_version("1.14")
2106+
@check_opset_min_version(11, "coordinate_transformation_mode attr")
2107+
def test_resize_bilinear_half_pixel_centers(self):
2108+
x_shape = [1, 15, 20, 2]
2109+
x_new_size = [30, 40]
2110+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
2111+
def func(x):
2112+
x_new_size_ = tf.constant(x_new_size)
2113+
x_ = resize_bilinear(x, x_new_size_, half_pixel_centers=True)
2114+
return tf.identity(x_, name=_TFOUTPUT)
2115+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2116+
21042117
@check_opset_min_version(9, "resize_bilinear")
21052118
def test_resize_bilinear_with_non_const(self):
21062119
x_shape = [3, 10, 8, 5]
@@ -2144,6 +2157,18 @@ def func(x):
21442157
return tf.identity(x_, name=_TFOUTPUT)
21452158
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
21462159

2160+
@check_tf_min_version("1.14")
2161+
@check_opset_min_version(11, "coordinate_transformation_mode attr")
2162+
def test_resize_nearest_neighbor_half_pixel_centers(self):
2163+
x_shape = [1, 10, 20, 2]
2164+
x_new_size = [20, 40]
2165+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
2166+
def func(x):
2167+
x_new_size_ = tf.constant(x_new_size)
2168+
x_ = resize_nearest_neighbor(x, x_new_size_, half_pixel_centers=True)
2169+
return tf.identity(x_, name=_TFOUTPUT)
2170+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2171+
21472172
@check_opset_min_version(9, "fill")
21482173
def test_fill_float32(self):
21492174
x_shape = [1, 15, 20, 2]

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,10 @@ def version_11(cls, ctx, node, **kwargs):
903903
if "align_corners" in node.attr and node.attr["align_corners"].i:
904904
transformation_mode = "align_corners"
905905
if "half_pixel_centers" in node.attr and node.attr["half_pixel_centers"].i:
906-
transformation_mode = "half_pixel"
906+
if node.type == "ResizeNearestNeighbor":
907+
transformation_mode = "tf_half_pixel_for_nn"
908+
else:
909+
transformation_mode = "half_pixel"
907910
resize = ctx.make_node("Resize", resize_inputs,
908911
attr={"mode": mode, "nearest_mode": "floor",
909912
"coordinate_transformation_mode": transformation_mode})

0 commit comments

Comments
 (0)