Commit 76866db

Merge pull request #62 from onnx/gs/onnx-1.2

- add support for onehot (rank=1 for now)
- onnx opset < 8 defines reshape for float only ... insert casts for now
- update tests
- pass backend for unittest via command line (fix for #33)

2 parents 8432455 + 5338003
5 files changed: +148 -30 lines

tests/run_pretrained_models.py

Lines changed: 10 additions & 0 deletions

@@ -271,8 +271,18 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None):
         graph_def = graph_pb2.GraphDef()
         with open(model_path, "rb") as f:
             graph_def.ParseFromString(f.read())
+
         g = tf.import_graph_def(graph_def, name='')
         with tf.Session(graph=g) as sess:
+
+            # fix inputs if needed
+            for k in inputs.keys():
+                t = sess.graph.get_tensor_by_name(k)
+                dtype = tf.as_dtype(t.dtype).name
+                if dtype != "float32":
+                    v = inputs[k]
+                    inputs[k] = v.astype(dtype)
+
             tf_results = self.run_tensorflow(sess, inputs)
             onnx_graph = None
             print("\ttensorflow", "OK")

tests/test_backend.py

Lines changed: 47 additions & 12 deletions

@@ -1,7 +1,9 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT license.
 
+import argparse
 import os
+import sys
 import tempfile
 import unittest
 from collections import namedtuple
@@ -553,6 +555,16 @@ def test_reshape(self):
         self.assertEqual(expected.shape, actual.shape)
         self.assertAllClose(expected, actual)
 
+    def test_reshape_int(self):
+        x_val = np.array([1, 2, 3, 4], dtype=np.int32).reshape((2, 2))
+        x = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
+        shape = tf.constant([1, 4])
+        x_ = tf.reshape(x, shape)
+        output = tf.identity(x_, name=_TFOUTPUT)
+        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+        self.assertEqual(expected.shape, actual.shape)
+        self.assertAllClose(expected, actual)
+
     @unittest.skipIf(OPSET < 5 or BACKEND in ["onnxmsrtnext"], "since opset 5, broken in msrtnext")
     def test_reshape_dynamic(self):
         x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))
@@ -788,13 +800,35 @@ def test_cast(self):
         actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
         self.assertAllClose(expected, actual)
 
-    @unittest.skip
-    def test_onehot(self):
+    def test_onehot0(self):
         # no such op in onnx
         x_val = np.array([0, 1, 2], dtype=np.int32)
+        depth = 5
+        for axis in [-1, 0, 1]:
+            tf.reset_default_graph()
+            x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
+            x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
+            output = tf.identity(x_, name=_TFOUTPUT)
+            actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+            self.assertAllClose(expected, actual)
+
+    @unittest.skip
+    def test_onehot1(self):
+        # no such op in onnx
+        x_val = np.array([[0, 2], [1, -1]], dtype=np.int32)
         depth = 3
         x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
-        x_ = tf.one_hot(x, depth, on_value=1, axis=0, off_value=0)
+        x_ = tf.one_hot(x, depth, on_value=5.0, axis=-1, off_value=0.0, dtype=tf.float32)
+        output = tf.identity(x_, name=_TFOUTPUT)
+        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+        self.assertAllClose(expected, actual)
+
+    def test_onehot2(self):
+        # no such op in onnx
+        x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
+        depth = 20
+        x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
+        x_ = tf.one_hot(x, depth, on_value=5.0, axis=-1, off_value=1.0, dtype=tf.float32)
         output = tf.identity(x_, name=_TFOUTPUT)
         actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
         self.assertAllClose(expected, actual)
@@ -854,14 +888,6 @@ def test_unstack_axis(self):
         actual, expected = self._run(output, {}, {})
         self.assertAllClose(expected, actual)
 
-    def test_unstack_axis1(self):
-        x_val = np.random.randn(10, 3, 4).astype("float32")
-        x = tf.constant(x_val, dtype=tf.float32)
-        x_ = tf.unstack(x, axis=1)
-        output = tf.identity(x_, name=_TFOUTPUT)
-        actual, expected = self._run(output, {}, {})
-        self.assertAllClose(expected, actual)
-
     @unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "Space2Depth not implemented, works on onnxmsrtnext")
     def test_space_to_depth(self):
         x_val = make_xval([1, 2, 2, 1])
@@ -871,7 +897,6 @@ def test_space_to_depth(self):
         actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
         self.assertAllClose(expected, actual)
 
-
     @unittest.skipIf(OPSET < 6, "supported since opset 6")
     def test_addn(self):
         x_val = np.arange(3*2*3).astype("float32")
@@ -935,4 +960,14 @@ def test_fill(self):
 
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--backend', default='caffe2',
+                        choices=["caffe2", "onnxmsrt", "onnxmsrtnext", "onnx-tensorflow"],
+                        help="backend to test against")
+    parser.add_argument('unittest_args', nargs='*')
+
+    args = parser.parse_args()
+    BACKEND = args.backend
+    # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
+    sys.argv[1:] = args.unittest_args
     unittest.main()
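
With the argparse shim in place, the backend is picked on the command line and everything after it is handed to unittest untouched. Illustrative invocations (the test-case class name is a stand-in for whatever test_backend.py actually defines):

    python tests/test_backend.py --backend caffe2
    python tests/test_backend.py --backend onnxmsrtnext BackendTests.test_onehot0

Because sys.argv[1:] is rewritten to unittest_args before unittest.main() runs, unittest never sees the --backend flag; this is what fixes #33.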

tests/unity.yaml

Lines changed: 1 addition & 3 deletions

@@ -42,13 +42,11 @@ BananaRL:
   - value_estimate:0
 
 Basic:
-  # needs: onehot
-  disabled: true
   url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Basic/TFModels/Basic.bytes
   model: Basic.bytes
   input_get: get_random
   inputs:
-    "vector_observation:0": [1, 1]
+    "vector_observation:0": [10, 1]
   outputs:
   - action:0
   - action_probs:0
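
With onehot now supported, the Basic model no longer needs to be disabled, and its observation feed widens from [1, 1] to [10, 1]. For orientation, a sketch of how an entry with input_get: get_random could become a feed dict (get_random is a helper in run_pretrained_models.py; this reimplementation is only illustrative):

import numpy as np

def get_random(shape):
    # random float32 feed values, as the pretrained-model runner generates them
    return np.random.sample(shape).astype(np.float32)

inputs_spec = {"vector_observation:0": [10, 1]}
feed = {name: get_random(shape) for name, shape in inputs_spec.items()}

The dtype fix in run_pretrained_models.py above then casts these values if the graph expects something other than float32.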

tf2onnx/graph.py

Lines changed: 19 additions & 1 deletion

@@ -129,6 +129,15 @@ def shape(self):
             shape[0] = utils.ONNX_UNKNOWN_DIMENSION
         return shape
 
+    def get_tensor_type(self):
+        """Get the numpy dtype of a tensor, defaulting to float32."""
+        t = self.get_attr("value")
+        if t:
+            t = helper.get_attribute_value(t)
+            if t:
+                return utils.ONNX_TO_NUMPY_DTYPE[t.data_type]
+        return np.float32
+
     def get_tensor_value(self):
         """Get value for onnx tensor."""
         if not self.is_const():
@@ -478,7 +487,8 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwargs):
         Returns:
             node that was inserted
         """
-        assert isinstance(input_name, str) and isinstance(op_type, str)
+        if name is None:
+            name = utils.make_name(node.name)
         new_output = name + ":0"
         new_node = Node(helper.make_node(op_type, [input_name], [new_output], name=name, **kwargs), self)
         for i, n in enumerate(node.input):
@@ -504,6 +514,14 @@ def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
         self.replace_all_inputs(self.get_nodes(), output_name, new_output)
         return new_node
 
+    def find_output_consumers(self, output_name):
+        """Find all nodes consuming a given output."""
+        nodes = []
+        for node in self.get_nodes():
+            if output_name in node.input:
+                nodes.append(node)
+        return nodes
+
     @staticmethod
     def replace_all_inputs(ops, old_input, new_input):
         """Replace all inputs pointing to old_input with new_input."""

tf2onnx/tfonnx.py

Lines changed: 71 additions & 14 deletions

@@ -29,6 +29,7 @@
 POSSIBLE_TARGETS = [TARGET_RS4, TARGET_CAFFE2]
 DEFAULT_TARGET = [TARGET_RS4, TARGET_CAFFE2]
 
+
 def tensorflow_to_onnx(graph):
     """
     Load tensorflow graph into an onnx graph with minimal rewrites so
@@ -119,10 +120,9 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
         shape_node.set_attr("value", onnx_tensor)
         return [node]
     else:
-        op_name = utils.make_name(node.name)
-        cast_op = ctx.insert_new_node_on_input(node, "Cast", name, name=op_name)
+        cast_op = ctx.insert_new_node_on_input(node, "Cast", name)
         cast_op.set_attr("to", onnx_pb.TensorProto.INT64)
-        ctx.copy_shape(name, op_name + ":0")
+        ctx.copy_shape(name, cast_op.output[0])
         return [cast_op, node]
 
 # pylint: disable=W0613,C0111,W0612
@@ -274,8 +274,29 @@ def reshape_op(ctx, node, name, args):
 
 
 def reshape_op5(ctx, node, name, args):
+    need_casting = node.dtype in [onnx_pb.TensorProto.INT32,
+                                  onnx_pb.TensorProto.INT16,
+                                  onnx_pb.TensorProto.INT64]
     # onnx wants reshape.input[1] to have the value be int64 which is not the case for tensorflow.
-    return _convert_shapenode_to_int64(ctx, node, 1)
+    nodes = _convert_shapenode_to_int64(ctx, node, 1)
+    if not need_casting:
+        # onnx reshape can handle the type - done
+        return nodes
+
+    # onnx < opset 8 supports reshape only for float types - wrap the reshape in casts
+    input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+    input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+    ctx.copy_shape(name, input_cast.output[0])
+
+    # if the next node is already a cast we don't need to insert another one
+    next_nodes = ctx.find_output_consumers(node.output[0])
+    if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
+        op_name = utils.make_name(node.name)
+        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
+        output_cast.set_attr("to", node.dtype)
+        ctx.copy_shape(name, output_cast.output[0])
+        nodes.append(output_cast)
+    return [input_cast] + nodes
 
 
 NCHW_TO_NHWC = [0, 2, 3, 1]
@@ -317,8 +338,7 @@ def calc_shape(a, b):
         else:
             # if input comes from a op, insert transpose op
             input_name = node.input[0]
-            op_name = utils.make_name(node.name)
-            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name, name=op_name)
+            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
             transpose.set_attr("perm", NHWC_TO_NCHW)
             transpose.inserted_nchw = True
             ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(input_name), NHWC_TO_NCHW))
@@ -336,9 +356,8 @@ def calc_shape(a, b):
                 parent.data_format = "NCHW"
         else:
             # kernel comes from op, insert transpose op
-            op_name = utils.make_name(node.name)
             input_name = node.input[1]
-            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name, name=op_name)
+            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
             transpose.set_attr("perm", HWCN_TO_NCHW)
             transpose.inserted_nchw = True
             ctx.copy_shape(input_name, transpose.output[0])
@@ -349,18 +368,16 @@ def calc_shape(a, b):
     if new_kernel_shape:
         if ctx.opset < 5:
             # old reshape takes new shape as attribute
-            op_name = utils.make_name(node.name)
             input_name = node.input[1]
-            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name, name=op_name)
+            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
             reshape.set_attr("shape", new_kernel_shape)
             ctx.set_shape(reshape.output[0], new_kernel_shape)
         else:
             # new reshape takes new shape as input[1]
-            op_name = utils.make_name(node.name)
             shape_name = utils.make_name(node.name)
             shape_node = ctx.make_const(shape_name, "Const", np.array(new_kernel_shape, dtype=np.int64))
             input_name = node.input[1]
-            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name, name=op_name)
+            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
             reshape.input.append(shape_name)
             ctx.set_shape(reshape.output[0], new_kernel_shape)
             nodes.append(reshape)
@@ -820,7 +837,7 @@ def minmax_op(ctx, node, name, args):
         input_node = node.inputs[i]
         dtype = ctx.dtypes[node.input[i]]
         zero_name = utils.make_name(input_node.name)
-        zero_node = ctx.make_const(zero_name, "Const", np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
+        ctx.make_const(zero_name, "Const", np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
         op_name = utils.make_name(input_node.name)
         output_name = op_name + ":0"
         add_node = Node(helper.make_node("Add", [input_node.output[0], zero_name],
@@ -853,6 +870,7 @@ def pack_op(ctx, node, name, args):
     ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], output_name)
     return [concat] + nodes
 
+
 def unpack_op(ctx, node, name, args):
     # hack to make up for the missing onnx unpack op
     axis = node.get_attr("axis").i
@@ -870,6 +888,44 @@ def unpack_op(ctx, node, name, args):
     return nodes
 
 
+def onehot_op(ctx, node, name, args):
+    # onnx has no onehot op yet - work around it by gathering from an eye matrix
+    indices_name = node.input[0]
+    indices_shape = ctx.get_shape(indices_name)
+    if len(indices_shape) != 1:
+        # TODO: this works for rank=1 but tensorflow supports more than this.
+        # The same principle should work but we need to implement our own eye.
+        raise ValueError("onehot op: only rank1 is supported")
+    axis = node.get_attr("axis")
+    # axis becomes axis for gather
+    node.set_attr("axis", 0)
+    depth = node.inputs[1].get_tensor_value()[0]
+    on = node.inputs[2].get_tensor_value()[0]
+    off = node.inputs[3].get_tensor_value()[0]
+    dtype = node.inputs[2].get_tensor_type()
+    eye = np.eye(depth, dtype=dtype)
+    if on != 0:
+        eye[eye == 1] = on
+        eye[eye == 0] = off
+    else:
+        eye[eye == 0] = off
+        eye[eye == 1] = on
+    const_name = utils.make_name(node.name)
+    ctx.make_const(const_name, "Const", eye)
+    # set up gather inputs
+    del node.input[:]
+    node.input.append(const_name)
+    node.input.append(indices_name)
+    node.type = "Gather"
+    if axis.i == 0:
+        # TODO: revisit for rank > 1
+        name = utils.make_name(node.name)
+        transpose_op = ctx.insert_new_node_on_output("Transpose", node.output[0], name)
+        ctx.copy_shape(node.output[0], transpose_op.output[0])
+        return [node, transpose_op]
+    return node
+
+
 # pylint: enable=W0613,C0111,W0612
 
 
@@ -962,6 +1018,7 @@ def unpack_op(ctx, node, name, args):
 _OPSET_5 = {
     "Reshape": (reshape_op5, []),
     "ExpandDims": (expanddims_op7, []),
+    "OneHot": (onehot_op, []),
 }
 
 _OPSET_6 = {
@@ -1183,7 +1240,7 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
 def tf_optimize(sess, inputs, outputs, graph_def):
     """Optimize tensorflow graph for inference."""
     transforms = [
-        #"fold_constants(ignore_errors=true)",
+        "fold_constants(ignore_errors=true)",
         "fold_batch_norms",
         "fold_old_batch_norms",
     ]
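
Two techniques above are worth unpacking. reshape_op5 wraps an integer Reshape as Cast(float) -> Reshape -> Cast(back to the original type), skipping the trailing Cast when the node's only consumer is already a Cast. And onehot_op relies on row i of a depth-by-depth eye matrix (with on/off values substituted) being exactly the one-hot encoding of i, so a Gather over that constant reproduces tf.one_hot for rank-1 indices. A quick numpy check of that identity (a sketch only; tf2onnx does this as graph surgery at conversion time, not at runtime):

import numpy as np

depth, on, off = 5, 5.0, 1.0
indices = np.array([0, 1, 2], dtype=np.int32)

# the constant onehot_op builds: an eye matrix with on/off substituted
eye = np.eye(depth, dtype=np.float32)
eye[eye == 1] = on
eye[eye == 0] = off

# Gather along axis 0 is a row lookup - this is the axis=-1 case
gathered = eye[indices]

# reference: what tf.one_hot(indices, depth, on_value=on, off_value=off) gives for axis=-1
expected = np.where(np.arange(depth) == indices[:, None], on, off).astype(np.float32)
assert np.allclose(gathered, expected)
# for axis=0 the handler appends a Transpose, i.e. gathered.T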
