Skip to content

Commit 315971d

Browse files
authored
Merge pull request #46 from onnx/gs/onnx-1.2
add support for tf.image.resize_bilinear, fix issue in tf.image.resize_nearest_neighbor
2 parents bd92ea4 + 85e436c commit 315971d

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ def run(self):
7474
version=VersionInfo.version,
7575
description='Tensorflow to ONNX converter',
7676
setup_requires=['pytest-runner'],
77-
tests_require=['numpy', 'pytest', 'pytest-cov', 'psutil', 'graphviz'],
77+
tests_require=['pytest', 'pytest-cov', 'psutil', 'graphviz', 'pyyaml'],
7878
cmdclass=cmdclass,
7979
packages=find_packages(),
8080
8181
author_email='[email protected]',
8282
url='https://github.com/onnx/tensorflow-onnx',
83-
install_requires=['pyyaml', 'onnx>=1.2']
83+
install_requires=['numpy>=1.14.1', 'onnx>=1.2']
8484
)

tests/test_backend.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -823,10 +823,10 @@ def test_strided_slice2(self):
823823
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
824824
self.assertAllClose(expected, actual)
825825

826-
@unittest.skip
826+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not correctly supported")
827827
def test_resize_nearest_neighbor(self):
828828
# this should work but no runtime I tried supports it.
829-
x_shape = [1, 15, 20, 3]
829+
x_shape = [1, 15, 20, 2]
830830
x_new_size = [30, 40]
831831
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
832832
x = tf.placeholder(tf.float32, x_shape, name=_TFINPUT)
@@ -836,6 +836,19 @@ def test_resize_nearest_neighbor(self):
836836
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
837837
self.assertAllClose(expected, actual)
838838

839+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not correctly supported")
840+
def test_resize_bilinear(self):
841+
# this should work but no runtime I tried supports it.
842+
x_shape = [1, 15, 20, 2]
843+
x_new_size = [30, 40]
844+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
845+
x = tf.placeholder(tf.float32, x_shape, name=_TFINPUT)
846+
x_new_size_ = tf.constant(x_new_size)
847+
x_ = tf.image.resize_bilinear(x, x_new_size_)
848+
output = tf.identity(x_, name=_TFOUTPUT)
849+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
850+
self.assertAllClose(expected, actual)
851+
839852
@unittest.skip
840853
def test_fill(self):
841854
# no official fill op in onnx

tf2onnx/tfonnx.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -749,18 +749,22 @@ def lrn_op(ctx, node, name, args):
749749
return node
750750

751751

752-
def upsample_op(ctx, node, name, args):
753-
node.type = "Upsample"
752+
def upsample_op7(ctx, node, name, args):
753+
mode = args[0]
754754
shape = ctx.get_shape(node.input[0])
755755
target_shape = node.inputs[1].get_tensor_value()
756756
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
757757
# wants the input to be NHWC - adjust target_shape to this.
758758
n, h, w, c = shape
759759
nh, nw = target_shape
760-
scaler = [float(n), float(nh) / h, float(nw) / w, float(c)]
760+
# scaler = [float(n), float(nh) / h, float(nw) / w, float(c)]
761+
scaler = [float(nh) / h, float(nw) / w]
761762
node.set_attr("scales", scaler)
763+
node.set_attr("mode", mode)
762764
ctx.remove_input(node, node.input[1])
763-
return node
765+
node.data_format = "NHWC"
766+
nodes = conv_convert_inputs(ctx, node, with_kernel=False)
767+
return nodes
764768

765769

766770
def multinomial_op(ctx, node, name, args):
@@ -897,7 +901,8 @@ def spacetodepth_op(ctx, node, name, args):
897901

898902
_OPSET_7 = {
899903
"Tile": (tile_op7, []),
900-
"ResizeNearestNeighbor": (upsample_op, []),
904+
"ResizeNearestNeighbor": (upsample_op7, ["Upsample", "nearest"]),
905+
"ResizeBilinear": (upsample_op7, ["Upsample", "linear"]),
901906
"BiasAdd": (biasadd_op7, []),
902907
"BiasAddV1": (biasadd_op7, []),
903908
"Add": (broadcast_op7, []),

0 commit comments

Comments
 (0)