
Commit 820acbc

Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into input
2 parents: 102d8f8 + d3d301a

File tree: 8 files changed, +43 −54 lines

tests/test_optimizers.py

Lines changed: 0 additions & 1 deletion

@@ -912,7 +912,6 @@ def test_duplicated_duplicated_attributes(self):
                                        op_type="ReduceSum", remaining_op_num=2)

     def _check_initializer_num(self, graph_proto, num):
-        print(len(graph_proto.initializer))
         return num == len(graph_proto.initializer)

     def test_duplicated_duplicated_constant(self):

tf2onnx/graph_matcher.py

Lines changed: 2 additions & 1 deletion

@@ -50,6 +50,7 @@ def __init__(self, op_type, name=None, inputs=None):
             input_pattern if isinstance(input_pattern, OpTypePattern) else
             OpTypePattern(input_pattern) for input_pattern in inputs
         ]
+        self.op_type_set = set(op_type.split('|')) if op_type else set()

     @property
     def op_type(self):

@@ -154,7 +155,7 @@ def _is_op_type_same(op, pattern):
     if pattern.op_type == "*":
         return True

-    if op.type in pattern.op_type.split('|'):
+    if op.type in pattern.op_type_set:
         return True

     return False
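
Note: this change precomputes the set of accepted op types once in OpTypePattern.__init__, so _is_op_type_same does a constant-time set lookup instead of re-splitting the "A|B|C" string on every match attempt. A minimal sketch of the idea, using a simplified stand-in class rather than the real tf2onnx OpTypePattern:

class OpTypePatternSketch:
    """Simplified stand-in; only illustrates the precomputed op_type_set."""

    def __init__(self, op_type):
        self._op_type = op_type
        # split once here instead of on every match attempt
        self.op_type_set = set(op_type.split('|')) if op_type else set()

    def matches(self, op_type):
        return self._op_type == "*" or op_type in self.op_type_set


pattern = OpTypePatternSketch("Add|Sub|Mul")
assert pattern.matches("Sub")
assert not pattern.matches("Conv2D")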

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 2 deletions

@@ -11,7 +11,6 @@

 import copy
 import logging
-import sys

 import numpy as np

@@ -524,7 +523,7 @@ def version_7(cls, ctx, node, **kwargs):
         maximum_iterations_name = node.input[1]
         maximum_iterations = node.inputs[1].get_tensor_value()
         if maximum_iterations == -1:
-            maximum_iterations = sys.maxsize
+            maximum_iterations = np.iinfo(np.int64).max
         consumers = ctx.find_output_consumers(maximum_iterations_name)
         external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
         if len(external_consumers) == 0:
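
Note: the Loop trip count in ONNX is an int64 tensor, so the replacement uses np.iinfo(np.int64).max, which is always 2**63 - 1, instead of the platform-dependent sys.maxsize (the two coincide on 64-bit CPython). A quick check of the values, assuming a 64-bit build:

import sys

import numpy as np

int64_max = np.iinfo(np.int64).max
print(int64_max)                 # 9223372036854775807
print(int64_max == 2 ** 63 - 1)  # True
print(int64_max == sys.maxsize)  # True on 64-bit CPython, not guaranteed elsewhere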

tf2onnx/onnx_opset/math.py

Lines changed: 4 additions & 4 deletions

@@ -183,7 +183,7 @@ def version_8(cls, ctx, node, **kwargs):

     @classmethod
     def version_12(cls, ctx, node, **kwargs):
-        node.name = 'Clip'  # clip supports all types now
+        node.type = 'Clip'  # clip supports all types now

 @tf_op("Softmax")
 class Softmax:

@@ -545,9 +545,9 @@ def version_11(cls, ctx, node, **kwargs):
                                 shapes=shapes, dtypes=dtypes, domain=constants.ONNX_DOMAIN, attr={'direction': direction})

         if node.maybe_cast_input([supported, supported], type_map):
-            cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                           name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", dtypes[0])
+            cast_back_node = ctx.insert_new_node_on_output(
+                "Cast", node.output[0], name=utils.make_name(node.name) + "_castback",
+                to=dtypes[0])
             ctx.set_dtype(cast_back_node.output[0], dtypes[0])
             ctx.copy_shape(node.name, cast_back_node.output[0])

tf2onnx/onnx_opset/nn.py

Lines changed: 6 additions & 8 deletions

@@ -637,14 +637,13 @@ def version_1(cls, ctx, node, **kwargs):
         origin_dtype = ctx.get_dtype(node.output[0])
         if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT,
                                 onnx_pb.TensorProto.DOUBLE]:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
             ctx.copy_shape(node.name, cast_node.output[0])

             cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                           name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", origin_dtype)
+                                                           name=utils.make_name(node.name) + "_castback",
+                                                           to=origin_dtype)
             ctx.set_dtype(cast_back_node.output[0], origin_dtype)
             ctx.copy_shape(node.name, cast_back_node.output[0])

@@ -667,14 +666,13 @@ def version_11(cls, ctx, node, **kwargs):
         origin_dtype = ctx.get_dtype(node.output[0])
         if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
                                 TensorProto.INT32, TensorProto.INT64]:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-            cast_node.set_attr("to", TensorProto.FLOAT)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
             ctx.copy_shape(node.name, cast_node.output[0])

             cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                           name=utils.make_name(node.name) + "_castback")
-            cast_back_node.set_attr("to", origin_dtype)
+                                                           name=utils.make_name(node.name) + "_castback",
+                                                           to=origin_dtype)
             ctx.set_dtype(cast_back_node.output[0], origin_dtype)
             ctx.copy_shape(node.name, cast_back_node.output[0])

tf2onnx/onnx_opset/tensor.py

Lines changed: 17 additions & 20 deletions

@@ -31,8 +31,7 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
     """cast int32 shape into int64 shape."""
     name = node.input[input_number]

-    cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
-    cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+    cast_node = ctx.insert_new_node_on_input(node, "Cast", name, to=onnx_pb.TensorProto.INT64)
     ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
     ctx.copy_shape(name, cast_node.output[0])

@@ -46,14 +45,14 @@ def _wrap_concat_with_cast(ctx, node):
     output_name = node.output[0]
     # cast each inputs to float
     for i, inp in enumerate(node.inputs):
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i],
+                                                  to=onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
     next_nodes = ctx.find_output_consumers(node.output[0])
     # cast output back to dtype unless the next op is a cast
     if next_nodes[0].type != "Cast":
-        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name())
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
+                                                    to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(output_name, output_cast.output[0])

@@ -157,15 +156,14 @@ def version_5(cls, ctx, node, **kwargs):
             return

         # onnx < opset 8 does not know reshape for other types than float*, wrap the reshape in casts
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.output[0], input_cast.output[0])

         # if the next node is already a cast we don't need to insert another one
         next_nodes = ctx.find_output_consumers(node.output[0])
         if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
-            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name())
-            output_cast.set_attr("to", dtype)
+            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name(),
+                                                        to=dtype)
             ctx.set_dtype(output_cast.output[0], dtype)
             ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -742,16 +740,17 @@ def version_1(cls, ctx, node, **kwargs):
        if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
            # override the previous cast
            cast_node = node.inputs[0]
+           cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
        else:
-           cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+           cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
+                                                    to=onnx_pb.TensorProto.FLOAT)
            nodes.insert(0, cast_node)
-       cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
        ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
        ctx.copy_shape(node.input[0], cast_node.output[0])
        # undo the cast afer slice
        name = utils.make_name(node.name)
-       cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
-       cast_node.set_attr("to", input_dtype)
+       cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name,
+                                                 to=input_dtype)
        ctx.set_dtype(cast_node.output[0], input_dtype)
        ctx.copy_shape(node.output[0], cast_node.output[0])
        nodes.append(cast_node)

@@ -1180,8 +1179,7 @@ def version_1(cls, ctx, node, **kwargs):
        if dtype == onnx_pb.TensorProto.INT64:
            return
        op_name = utils.make_name(node.name)
-       output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
-       output_cast.set_attr("to", dtype)
+       output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, to=dtype)
        ctx.set_dtype(output_cast.output[0], dtype)
        ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -1555,8 +1553,7 @@ def version_8(cls, ctx, node, **kwargs):

        seq_len_dtype = ctx.get_dtype(node.input[1])
        if seq_len_dtype != onnx_pb.TensorProto.INT64:
-           cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
-           cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+           cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
            ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
            ctx.copy_shape(node.input[1], cast_node.output[0])

@@ -1762,8 +1759,8 @@ def version_11(cls, ctx, node, **kwargs):
        # cast to int64 if needed
        if dtypes[1] != onnx_pb.TensorProto.UINT64:
            cast_node = ctx.insert_new_node_on_output("Cast", node.output[1],
-                                                      name=utils.make_name(node.name) + "_cast")
-           cast_node.set_attr("to", dtypes[1])
+                                                      name=utils.make_name(node.name) + "_cast",
+                                                      to=dtypes[1])
            ctx.set_dtype(cast_node.output[0], dtypes[1])
            ctx.copy_shape(node.output[1], cast_node.output[0])
        # FIXME: the indices in onnx are not the same as in tensorflow.

tf2onnx/tf_utils.py

Lines changed: 9 additions & 14 deletions

@@ -13,7 +13,6 @@
 from distutils.version import LooseVersion

 import numpy as np
-import six
 import tensorflow as tf

 from tensorflow.core.framework import types_pb2, tensor_pb2

@@ -70,7 +69,7 @@ def get_tf_tensor_data(tensor):
     """Get data from tensor."""
     make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto")
     np_data = tensor_util.MakeNdarray(tensor)
-    make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data))
+    make_sure(isinstance(np_data, np.ndarray), "%r isn't ndarray", np_data)
     return np_data

@@ -83,7 +82,7 @@ def get_tf_const_value(op, as_list=True):
    when as_list=False, return np.array(1), type is <class 'numpy.ndarray'>
    when as_list=True, return 1, type is <class 'int'>.
    """
-   make_sure(is_tf_const_op(op), "{} isn't a const op".format(op.name))
+   make_sure(is_tf_const_op(op), "%r isn't a const op", op.name)
    value = get_tf_tensor_data(op.get_attr("value"))
    if as_list:
        value = value.tolist()

@@ -119,9 +118,6 @@ def map_tf_dtype(dtype):

 def get_tf_node_attr(node, name):
     """Parser TF node attribute."""
-    if six.PY2:
-        # For python2, TF get_attr does not accept unicode
-        name = str(name)
     return node.get_attr(name)

@@ -136,14 +132,14 @@ def tflist_to_onnx(g, shape_override):
    """

    # ignore the following attributes
-   ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
+   ignored_attr = {"unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
                    "TI", "Tparams", "Tindices", "Tlen", "Tdim", "Tin", "dynamic_size", "Tmultiples",
                    "Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
                    "Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond",
                    "T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
                    "parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
                    "key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
-                   "Toutput_types"]
+                   "Toutput_types"}

    node_list = g.get_operations()
    functions = {}

@@ -176,12 +172,11 @@ def tflist_to_onnx(g, shape_override):
            attr_cnt[a] += 1
            if a == "dtype":
                attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
-           elif a in ["T"]:
+           elif a == "T":
                dtype = get_tf_node_attr(node, a)
-               if dtype:
-                   if not isinstance(dtype, list):
-                       dtypes[node.name] = map_tf_dtype(dtype)
-           elif a in ["output_type", "output_dtype", "out_type", "Tidx", "out_idx"]:
+               if dtype and not isinstance(dtype, list):
+                   dtypes[node.name] = map_tf_dtype(dtype)
+           elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx"}:
                # Tidx is used by Range
                # out_idx is used by ListDiff
                attr[a] = map_tf_dtype(get_tf_node_attr(node, a))

@@ -192,7 +187,7 @@ def tflist_to_onnx(g, shape_override):
            elif a == "output_shapes":
                # we should not need it since we pull the shapes above already
                pass
-           elif a in ["body", "cond", "then_branch", "else_branch"]:
+           elif a in {"body", "cond", "then_branch", "else_branch"}:
                input_shapes = [inp.get_shape() for inp in node.inputs]
                nattr = get_tf_node_attr(node, a)
                attr[a] = nattr.name
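
Note: two small idioms in this file: make_sure now takes a format string plus arguments, so the message is only rendered when the check actually fails (the same convention the logging module uses), and membership tests switch from lists to set literals for constant-time lookups. A sketch with a hypothetical minimal make_sure (the real helper lives in tf2onnx.utils):

def make_sure(condition, error_msg, *args):
    """Hypothetical minimal version: format the message only on failure."""
    if not condition:
        raise ValueError(error_msg % args)


make_sure(True, "%r isn't ndarray", [1, 2, 3])       # nothing is formatted here
try:
    make_sure(False, "%r isn't a const op", "my_op")
except ValueError as exc:
    print(exc)                                        # 'my_op' isn't a const op

# Set literals give O(1) membership checks for the attribute-name tests:
ignored_attr = {"unknown_rank", "_class", "Tshape"}
print("Tshape" in ignored_attr)                       # True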

tf2onnx/tfonnx.py

Lines changed: 4 additions & 4 deletions

@@ -160,8 +160,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
                    input_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
                    g.set_dtype(input_name, onnx_pb.TensorProto.FLOAT)
                else:
-                   cast_node = g.insert_new_node_on_input(op, "Cast", input_name)
-                   cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+                   cast_node = g.insert_new_node_on_input(op, "Cast", input_name,
+                                                          to=onnx_pb.TensorProto.FLOAT)
                    g.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
                    g.copy_shape(input_name, cast_node.output[0])
                    cast_inserted.append(cast_node)

@@ -171,8 +171,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
            name = utils.make_name(op.name)
            logger.debug("insert cast back for node %s on output %s [dtype=%s]", op.name, output_name,
                         output_dtype)
-           output_cast = g.insert_new_node_on_output("Cast", output_name, name=name)
-           output_cast.set_attr("to", output_dtype)
+           output_cast = g.insert_new_node_on_output("Cast", output_name, name=name,
+                                                     to=output_dtype)
            g.set_dtype(output_cast.output[0], output_dtype)
            g.copy_shape(output_name, output_cast.output[0])
            cast_inserted.append(output_cast)
