
Commit 89ad32a

Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into input2
2 parents 094ab85 + 62f1e70

9 files changed: 79 additions, 29 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
         onnx graph
     """
 ```
-For example in [examples/call_coverter_via_python.py]():
+For example in [examples/call_converter_via_python.py]():
 ```
 import tensorflow as tf
 import tf2onnx
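
Note: the corrected example file demonstrates the Python API named in the hunk header. A minimal sketch of that call path follows; the toy graph, the input_names/output_names keyword arguments, and the output path are illustrative assumptions, not taken from this commit.

```python
# Sketch of calling the converter from Python (assumed TF1-style session API;
# the toy graph, input/output names, and output path are illustrative).
import tensorflow as tf
import tf2onnx

with tf.Session() as sess:
    x = tf.placeholder(tf.float32, [2, 3], name="input")
    _ = tf.add(x, x, name="output")
    onnx_graph = tf2onnx.tfonnx.process_tf_graph(
        sess.graph, input_names=["input:0"], output_names=["output:0"])
    # make_model builds the final ModelProto from the internal graph
    model_proto = onnx_graph.make_model("converted from tf")
    with open("model.onnx", "wb") as f:
        f.write(model_proto.SerializeToString())
```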

tests/test_internals.py

Lines changed: 2 additions & 2 deletions
@@ -226,7 +226,7 @@ def test_node_attr_onnx(self):
         g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
         n1 = g.get_node_by_name("n1")
         self.assertTrue("my_attr" in n1.attr)
-        self.assertTrue("my_attr" not in n1.attr_onnx)
+        self.assertTrue("my_attr" not in n1.get_onnx_attrs())
 
         n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", domain="my_domain", my_attr="my_attr")
         graph_proto = helper.make_graph(
@@ -240,7 +240,7 @@ def test_node_attr_onnx(self):
         g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
         n1 = g.get_node_by_name("n1")
         self.assertTrue("my_attr" in n1.attr)
-        self.assertTrue("my_attr" in n1.attr_onnx)
+        self.assertTrue("my_attr" in n1.get_onnx_attrs())
 
     def test_tensor_data(self):
         tensors = {

tf2onnx/graph.py

Lines changed: 45 additions & 14 deletions
@@ -28,6 +28,13 @@
 # todo(pengwa): remove protected-access later
 # pylint: disable=broad-except,protected-access
 
+class ExternalTensorStorage():
+    """Passed into graph and node methods to accumulate tensors to save externally"""
+    def __init__(self):
+        self.name_to_tensor_data = {}
+        self.name_counter = 0
+        self.external_tensor_size_threshold = 1024
+        self.node_to_modified_value_attr = {}
 
 class Node(object):
     """A Node - wrapper around onnx nodes that we use for graph manipulations."""
@@ -93,16 +100,40 @@ def inputs(self):
     def attr(self):
         return self._attr
 
-    @property
-    def attr_onnx(self):
-        """Return onnx valid attributes"""
+    def get_value_attr(self, external_tensor_storage=None):
+        """Return onnx attr for value property of node.
+        Attr is modified to point to external tensor data stored in external_tensor_storage, if included.
+        """
+        a = self._attr["value"]
+        if external_tensor_storage is not None and self in external_tensor_storage.node_to_modified_value_attr:
+            return external_tensor_storage.node_to_modified_value_attr[self]
+        if external_tensor_storage is None or a.type != AttributeProto.TENSOR:
+            return a
+        if np.product(a.t.dims) > external_tensor_storage.external_tensor_size_threshold:
+            a = copy.copy(a)
+            tensor_name = self.name + "_" + str(external_tensor_storage.name_counter)
+            external_tensor_storage.name_counter += 1
+            external_tensor_storage.name_to_tensor_data[tensor_name] = a.t.raw_data
+            external_tensor_storage.node_to_modified_value_attr[self] = a
+            a.t.raw_data = b'__EXTERNAL'
+            location = a.t.external_data.add()
+            location.key = "location"
+            location.value = tensor_name
+            a.t.data_location = TensorProto.EXTERNAL
+        return a
+
+    def get_onnx_attrs(self, external_tensor_storage=None):
+        """Return onnx valid attributes.
+        Attrs point to external tensor data stored in external_tensor_storage, if included."""
         schema = get_schema(self.type, self.graph.opset, self.domain)
         if schema is None and not (self.is_const() or self.is_graph_input()):
             logger.debug("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check",
                          self.name, self.domain, self.type)
         onnx_attrs = {}
         for a in self._attr.values():
-            if schema is None or schema.has_attribute(a.name):
+            if a.name == "value":
+                onnx_attrs[a.name] = self.get_value_attr(external_tensor_storage)
+            elif schema is None or schema.has_attribute(a.name):
                 onnx_attrs[a.name] = a
         return onnx_attrs
 
@@ -333,7 +364,7 @@ def set_body_graph_as_attr(self, attr_name, graph):
         self.graph.contained_graphs[self.name].update({attr_name: graph})
         graph.parent_graph = self.graph
 
-    def update_proto(self):
+    def update_proto(self, external_tensor_storage=None):
         """Update protobuf from internal structure."""
         nodes = list(self._op.input)
         for node in nodes:
@@ -351,10 +382,10 @@ def update_proto(self):
         attr_graphs = self.get_body_graphs()
         if attr_graphs:
             for attr_name, sub_graph in attr_graphs.items():
-                graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name)
+                graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name, external_tensor_storage)
                 self.set_attr(attr_name, graph_proto)
 
-        attr = list(self.attr_onnx.values())
+        attr = list(self.get_onnx_attrs(external_tensor_storage).values())
         if attr:
             self._op.attribute.extend(attr)
 
@@ -766,10 +797,10 @@ def update_node_shape_dtype(self, node, override=False):
         self.set_shape(output, shape)
         logger.debug("Set shape of [%s] to %s", output, shape)
 
-    def update_proto(self):
+    def update_proto(self, external_tensor_storage=None):
         """Update the onnx protobuf from out internal Node structure."""
         for node in self._nodes:
-            node.update_proto()
+            node.update_proto(external_tensor_storage)
 
     def get_nodes(self):
         """Get node list."""
@@ -988,7 +1019,7 @@ def _get_unvisited_child(g, node, not_visited):
         ret = [x for _, x in sorted(zip(label, ops))]
         self.reset_nodes(ret)
 
-    def make_graph(self, doc, graph_name=None):
+    def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
         """
         Create GraphProto for onnx from internal graph.
         Args:
@@ -998,7 +1029,7 @@ def make_graph(self, doc, graph_name=None):
         graph_name = graph_name or self.graph_name
         self.delete_unused_nodes(self.outputs)
         self.topological_sort(self.get_nodes())
-        self.update_proto()
+        self.update_proto(external_tensor_storage)
 
         # TODO: we'd want to do something like this so that transpose optimizer is active
         # for all (unit) tests
@@ -1041,7 +1072,7 @@ def make_graph(self, doc, graph_name=None):
             # not to use numpy_helper.from_array to create a new tensor
            # because sometimes onnx will have a bug that only check the tensor data in specific field
            # such as at upsample it only checks the float_data field.
-            t = op.get_attr("value")
+            t = op.get_value_attr(external_tensor_storage)
            tensor = helper.get_attribute_value(t)
            tensor.name = op.output[0]
            initializers.append(tensor)
@@ -1070,14 +1101,14 @@ def make_graph(self, doc, graph_name=None):
 
         return graph
 
-    def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs):
+    def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", external_tensor_storage=None, **kwargs):
         """
         Create final ModelProto for onnx from internal graph.
         Args:
             optimize: optimize graph via onnx
             doc: text for doc string of the model
         """
-        graph = self.make_graph(graph_doc, graph_name)
+        graph = self.make_graph(graph_doc, graph_name, external_tensor_storage)
 
         if "producer_name" not in kwargs:
             kwargs = {"producer_name": "tf2onnx",
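
Note: the ExternalTensorStorage additions above let make_model divert large constant tensors out of the ModelProto and record a per-tensor location instead. A minimal usage sketch follows, under stated assumptions: the Graph instance g, the output directory, and the file-writing loop are hypothetical; only the storage class and the external_tensor_storage keyword come from this diff.

```python
# Sketch only: tensors whose element count exceeds the threshold are collected
# in storage.name_to_tensor_data during serialization; each one is then written
# next to the model under the name recorded in its external_data 'location'.
import os
from tf2onnx.graph import ExternalTensorStorage

storage = ExternalTensorStorage()
storage.external_tensor_size_threshold = 1024  # element count, per the check in get_value_attr

# 'g' is assumed to be a tf2onnx Graph built earlier in the conversion.
model_proto = g.make_model("converted graph", external_tensor_storage=storage)

model_dir = "exported_model"  # hypothetical output location
os.makedirs(model_dir, exist_ok=True)
for tensor_name, raw_data in storage.name_to_tensor_data.items():
    with open(os.path.join(model_dir, tensor_name), "wb") as f:
        f.write(raw_data)
with open(os.path.join(model_dir, "model.onnx"), "wb") as f:
    f.write(model_proto.SerializeToString())
```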

tf2onnx/onnx_opset/math.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ class RealDiv(common.BroadcastOp):
     pass
 
 
-@tf_op(["LeakyRelu", "LogSoftmax", "Softplus", "Softsign"])
+@tf_op(["LeakyRelu", "Softplus", "Softsign"])
 class DirectOpSinceOpset1:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
@@ -185,7 +185,7 @@ def version_8(cls, ctx, node, **kwargs):
     def version_12(cls, ctx, node, **kwargs):
         node.type = 'Clip' # clip supports all types now
 
-@tf_op("Softmax")
+@tf_op(["LogSoftmax", "Softmax"])
 class Softmax:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -281,7 +281,7 @@ def _handle_nhwc_tranpose(self, trans):
                 return False
             # move transpose into branches to let Transposes can be "handled" in each branch
             for n in out_nodes:
-                branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.attr_onnx)
+                branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.get_onnx_attrs())
                 n.graph.replace_input(n, trans.output[0], branch_trans.output[0])
 
             self._g.remove_node(trans.name)
@@ -406,7 +406,7 @@ def _add_handler(self, trans, node):
             target_node.set_tensor_value(target_val)
 
             conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
-            conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
+            conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.get_onnx_attrs())
             self._g.replace_input(trans, trans.input[0], utils.port_name(conv_node.name), 0)
             self._g.replace_all_inputs(node.output[0], trans.output[0]) # ops=self._g.get_nodes()
             self._g.remove_node(t_p.name)

tf2onnx/schemas.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def build_onnx_op(node):
                 copied_sub_graph = copy.deepcopy(sub_graph)
                 graph_proto = copied_sub_graph.make_graph("graph for " + node.name + " " + attr_name)
                 attr.append(helper.make_attribute(attr_name, graph_proto))
-        attr.extend(node.attr_onnx.values())
+        attr.extend(node.get_onnx_attrs().values())
         if attr:
             onnx_node.attribute.extend(attr)
         return onnx_node

tf2onnx/tf_utils.py

Lines changed: 24 additions & 5 deletions
@@ -124,8 +124,24 @@ def get_tf_node_attr(node, name):
 def get_tf_version():
     return LooseVersion(tf.__version__)
 
-
-def tflist_to_onnx(g, shape_override):
+def compress_graph_def(graph_def):
+    """
+    Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.
+    """
+    node_defs = list(graph_def.node)
+    const_node_values = {}
+    for node_def in node_defs:
+        if node_def.op == 'Const':
+            tensor = node_def.attr["value"].tensor
+            # Small constants are sometimes used to store shape information and must be maintained
+            if len(tensor.tensor_content) > 1000:
+                make_sure(node_def.name not in const_node_values, "Two nodes in graph have same name %s", node_def.name)
+                const_node_values[node_def.name] = tensor.tensor_content
+                tensor.tensor_content = b''
+    return const_node_values
+
+
+def tflist_to_onnx(g, shape_override, const_node_values=None):
     """
     Convert the tf-node list into an onnx graph with minimal rewrites so
     we can use the onnx graph as intermediate graph.
@@ -193,7 +209,10 @@ def tflist_to_onnx(g, shape_override):
                 attr[a] = nattr.name
                 functions[nattr.name] = input_shapes
             elif a == "value":
-                onnx_tensor = tf_to_onnx_tensor(get_tf_node_attr(node, a), name=port_name(node.name))
+                tensor = get_tf_node_attr(node, a)
+                if const_node_values and node.name in const_node_values:
+                    tensor.tensor_content = const_node_values[node.name]
+                onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
                 attr[a] = onnx_tensor
             elif a == "DstT":
                 attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
@@ -217,8 +236,8 @@ def tflist_to_onnx(g, shape_override):
     return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions
 
 
-def tensorflow_to_onnx(graph, shape_override):
+def tensorflow_to_onnx(graph, shape_override, const_node_values=None):
     """
     Load tensorflow graph and do a conversion.
     """
-    return tflist_to_onnx(graph, shape_override)
+    return tflist_to_onnx(graph, shape_override, const_node_values)
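
Note: compress_graph_def and the new const_node_values parameter are designed to round-trip large constants around the TensorFlow import step. A rough sketch of how they could fit together follows; the frozen-graph path and the surrounding import flow are assumptions, only compress_graph_def, tensorflow_to_onnx, and its return tuple come from this diff.

```python
# Sketch: blank out large Const payloads before importing into TF, keep the raw
# bytes in const_node_values, and hand them back when building the ONNX node list.
# "frozen_graph.pb" and the surrounding flow are illustrative assumptions.
import tensorflow as tf
from tf2onnx.tf_utils import compress_graph_def, tensorflow_to_onnx

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("frozen_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Large tensor contents are moved into const_node_values and cleared in graph_def,
# so importing and shape inference stay cheap.
const_node_values = compress_graph_def(graph_def)

with tf.Graph().as_default() as tf_graph:
    tf.compat.v1.import_graph_def(graph_def, name="")

# The stripped contents are restored while converting the node list to ONNX.
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tensorflow_to_onnx(
    tf_graph, shape_override={}, const_node_values=const_node_values)
```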

tf2onnx/utils.py

Lines changed: 2 additions & 2 deletions
@@ -405,9 +405,9 @@ def is_same(node_1, node_2):
     if node_1.type != node_2.type:
         return False
     # check onnx attributes
-    if node_1.attr_onnx.keys() != node_2.attr_onnx.keys():
+    if node_1.get_onnx_attrs().keys() != node_2.get_onnx_attrs().keys():
         return False
-    for name in node_1.attr_onnx.keys(): # pylint: disable=consider-iterating-dictionary
+    for name in node_1.get_onnx_attrs().keys(): # pylint: disable=consider-iterating-dictionary
         if node_1.get_attr_value(name) != node_2.get_attr_value(name):
             return False
     return True
