Skip to content

Commit cec8368

Browse files
authored
Merge pull request #346 from zhijxu-MS/push_branch
optimize upsample's conversion logic
2 parents 972adea + 395c657 commit cec8368

File tree

5 files changed

+63
-17
lines changed

5 files changed

+63
-17
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
150150
if check_shape:
151151
self.assertEqual(expected_val.shape, actual_val.shape)
152152

153+
return g
154+
153155
def save_onnx_model(self, model_proto, feed_dict, postfix=""):
154156
target_path = utils.save_onnx_model(self.test_data_directory, self._testMethodName + postfix, feed_dict,
155157
model_proto, include_test_data=self.config.is_debug_mode,

tests/common.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import sys
99
import unittest
10+
from collections import defaultdict
1011

1112
from distutils.version import LooseVersion
1213
from tf2onnx import utils
@@ -15,7 +16,7 @@
1516
__all__ = ["TestConfig", "get_test_config", "unittest_main",
1617
"check_tf_min_version", "skip_tf_versions",
1718
"check_opset_min_version", "check_target", "skip_onnxruntime_backend", "skip_caffe2_backend",
18-
"check_onnxruntime_incompatibility"]
19+
"check_onnxruntime_incompatibility", "validate_const_node", "group_nodes_by_type"]
1920

2021

2122
# pylint: disable=missing-docstring
@@ -209,3 +210,17 @@ def check_onnxruntime_incompatibility(op):
209210

210211
reason = "{} is not supported by onnxruntime before opset {}".format(op, support_since[op])
211212
return unittest.skipIf(True, reason)
213+
214+
215+
def validate_const_node(node, expected_val):
216+
if node.is_const():
217+
node_val = node.get_tensor_value()
218+
return node_val == expected_val
219+
return False
220+
221+
222+
def group_nodes_by_type(graph):
223+
res = defaultdict(list)
224+
for node in graph.get_nodes():
225+
res[node.type].append(node)
226+
return res

tests/test_backend.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class BackendTests(Tf2OnnxBackendTestBase):
9292
def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
9393
kwargs["convert_var_to_const"] = False
9494
kwargs["constant_fold"] = False
95-
self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
95+
return self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
9696

9797
def _test_expand_dims(self, idx):
9898
tf.reset_default_graph()
@@ -1276,7 +1276,10 @@ def test_resize_nearest_neighbor(self):
12761276
x_new_size_ = tf.constant(x_new_size)
12771277
x_ = tf.image.resize_nearest_neighbor(x, x_new_size_)
12781278
_ = tf.identity(x_, name=_TFOUTPUT)
1279-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1279+
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
1280+
if self.config.opset >= 9:
1281+
scale_node = group_nodes_by_type(graph)["Upsample"][0].inputs[1]
1282+
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 2.0, 2.0]))
12801283

12811284
@check_opset_min_version(9, "resize_nearest_neighbor")
12821285
def test_resize_nearest_neighbor_with_non_const(self):
@@ -1301,7 +1304,10 @@ def test_resize_bilinear(self):
13011304
x_new_size_ = tf.constant(x_new_size)
13021305
x_ = tf.image.resize_bilinear(x, x_new_size_)
13031306
_ = tf.identity(x_, name=_TFOUTPUT)
1304-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1307+
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
1308+
if self.config.opset >= 9:
1309+
scale_node = group_nodes_by_type(graph)["Upsample"][0].inputs[1]
1310+
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 2.0, 2.0]))
13051311

13061312
@check_opset_min_version(9, "resize_bilinear")
13071313
def test_resize_bilinear_with_non_const(self):

tf2onnx/graph.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,19 @@ def is_target(self, *names):
373373
"""Return True if target platform contains any name."""
374374
return any(name in self._target for name in names)
375375

376-
def make_const(self, name, np_val, skip_conversion=False):
376+
def make_const(self, name, np_val, skip_conversion=False, raw=True):
377377
"""Make a new constant in the graph.
378378
Args:
379379
name: const node name, must be unique.
380380
np_val: value of type numpy ndarray.
381381
skip_conversion: bool, indicate whether this created node would be mapped during conversion.
382+
raw: whether to store data at field of raw_data or the specific field according to its dtype
382383
"""
383-
onnx_tensor = numpy_helper.from_array(np_val, name)
384+
if raw:
385+
onnx_tensor = numpy_helper.from_array(np_val, name)
386+
else:
387+
onnx_tensor = helper.make_tensor(name, utils.map_numpy_to_onnx_dtype(np_val.dtype),
388+
np_val.shape, np_val, raw=False)
384389
node = self.make_node("Const", [], outputs=[name], name=name, attr={"value": onnx_tensor},
385390
skip_conversion=skip_conversion)
386391
self.set_shape(name, np_val.shape)
@@ -732,8 +737,12 @@ def make_graph(self, doc, graph_name="tf2onnx"):
732737
# create initializers for constant nodes
733738
const_ops = [op for op in const_ops if op not in placeholder_default_const_ops]
734739
for op in const_ops:
735-
const_val = op.get_tensor_value(as_list=False)
736-
tensor = numpy_helper.from_array(const_val, op.output[0])
740+
# not to use numpy_helper.from_array to create a new tensor
741+
# because sometimes onnx will have a bug that only check the tensor data in specific field
742+
# such as at upsample it only checks the float_data field.
743+
t = op.get_attr("value")
744+
tensor = helper.get_attribute_value(t)
745+
tensor.name = op.output[0]
737746
initializers.append(tensor)
738747

739748
# create input_tensor_values

tf2onnx/tfonnx.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,18 +1056,32 @@ def upsample_op9(ctx, node, name, args):
10561056
# float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
10571057
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
10581058
# wants the input to be NHWC - adjust target_shape to this.
1059-
ori_shape = ctx.make_node("Shape", [node.input[0]])
1060-
ori_shape_hw = ctx.make_node("Slice", ori_shape.output, {"axes": [0], "starts": [1], "ends": [3]})
1061-
ori_shape_hw_float = ctx.make_node("Cast", ori_shape_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
10621059

1063-
target_hw = node.inputs[1]
1064-
target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
1060+
# first create "scales" info for onnx upsample
1061+
# if shape of input and output known then "scale" is calculated statically and set as a const node
1062+
shape = ctx.get_shape(node.input[0])
1063+
if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
1064+
target_shape = node.inputs[1].get_tensor_value()
1065+
n, h, w, c = shape
1066+
nh, nw = target_shape
1067+
# scales is nchw
1068+
# the reason not storing data at raw field is because of the bug: https://github.com/onnx/onnx/issues/1852
1069+
scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
1070+
scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
1071+
else:
1072+
ori_shape = ctx.make_node("Shape", [node.input[0]])
1073+
ori_shape_hw = ctx.make_node("Slice", ori_shape.output, {"axes": [0], "starts": [1], "ends": [3]})
1074+
ori_shape_hw_float = ctx.make_node("Cast", ori_shape_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
10651075

1066-
scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])
1076+
target_hw = node.inputs[1]
1077+
target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
10671078

1068-
const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
1069-
# scaler is nchw
1070-
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
1079+
scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])
1080+
1081+
const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
1082+
# scales is nchw
1083+
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
1084+
# because onnxruntime only supports to scale the last two dims so transpose is inserted
10711085
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]})
10721086
upsample = ctx.make_node("Upsample", [input_nchw.output[0], scales.output[0]], attr={"mode": args[0]})
10731087

0 commit comments

Comments
 (0)