Skip to content

Commit b2b81a3

Browse files
Added graph methods for saving using the external data storage format
1 parent 45c2ec7 commit b2b81a3

File tree

5 files changed

+52
-21
lines changed

5 files changed

+52
-21
lines changed

tests/test_internals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_node_attr_onnx(self):
226226
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
227227
n1 = g.get_node_by_name("n1")
228228
self.assertTrue("my_attr" in n1.attr)
229-
self.assertTrue("my_attr" not in n1.attr_onnx)
229+
self.assertTrue("my_attr" not in n1.get_onnx_attrs())
230230

231231
n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", domain="my_domain", my_attr="my_attr")
232232
graph_proto = helper.make_graph(
@@ -240,7 +240,7 @@ def test_node_attr_onnx(self):
240240
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
241241
n1 = g.get_node_by_name("n1")
242242
self.assertTrue("my_attr" in n1.attr)
243-
self.assertTrue("my_attr" in n1.attr_onnx)
243+
self.assertTrue("my_attr" in n1.get_onnx_attrs())
244244

245245
def test_tensor_data(self):
246246
tensors = {

tf2onnx/graph.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@
2828
# todo(pengwa): remove protected-access later
2929
# pylint: disable=broad-except,protected-access
3030

31+
class ExternalTensorStorage(object):
    """Passed into graph and node methods to accumulate tensors to save externally.

    Args:
        external_tensor_size_threshold: a tensor is moved to external storage
            when the product of its dims is greater than this value.
            Defaults to 1024.
    """

    def __init__(self, external_tensor_size_threshold=1024):
        # external tensor name -> raw tensor bytes that were moved out of the attr
        self.name_to_tensor_data = {}
        # incremented for every externalized tensor; appended to the node name
        # to build the external tensor name
        self.name_counter = 0
        self.external_tensor_size_threshold = external_tensor_size_threshold
        # node -> copy of its "value" attr rewritten to point at external data,
        # so a node is only externalized once per storage object
        self.node_to_modified_value_attr = {}
3138

3239
class Node(object):
3340
"""A Node - wrapper around onnx nodes that we use for graph manipulations."""
@@ -88,16 +95,40 @@ def inputs(self):
8895
def attr(self):
8996
return self._attr
9097

91-
@property
92-
def attr_onnx(self):
93-
"""Return onnx valid attributes"""
98+
def get_value_attr(self, external_tensor_storage=None):
    """Return onnx attr for value property of node.

    Attr is modified to point to external tensor data stored in
    external_tensor_storage, if included.

    Args:
        external_tensor_storage: optional ExternalTensorStorage. When None the
            node's own "value" attr is returned untouched.

    Returns:
        The node's "value" attr, or - for tensors whose element count exceeds
        the storage's threshold - a copy whose raw data has been moved into
        external_tensor_storage and replaced by an external-data reference.
        The same copy is returned on later calls with the same storage.

    Raises:
        KeyError: if the node has no "value" attribute.
    """
    value_attr = self._attr["value"]
    if external_tensor_storage is None:
        return value_attr
    if self in external_tensor_storage.node_to_modified_value_attr:
        # already externalized for this storage; reuse so the data is stored once
        return external_tensor_storage.node_to_modified_value_attr[self]
    if value_attr.type != AttributeProto.TENSOR:
        return value_attr
    # np.prod instead of np.product: the latter is a deprecated alias that
    # newer numpy releases no longer provide.
    if np.prod(value_attr.t.dims) > external_tensor_storage.external_tensor_size_threshold:
        # Deep copy so that rewriting the nested tensor below can never alter
        # the attr still held by the node itself.
        value_attr = copy.deepcopy(value_attr)
        tensor_name = self.name + "_" + str(external_tensor_storage.name_counter)
        external_tensor_storage.name_counter += 1
        # stash the real bytes before overwriting them with the placeholder
        external_tensor_storage.name_to_tensor_data[tensor_name] = value_attr.t.raw_data
        external_tensor_storage.node_to_modified_value_attr[self] = value_attr
        value_attr.t.raw_data = b'__EXTERNAL'
        location = value_attr.t.external_data.add()
        location.key = "location"
        location.value = tensor_name
        value_attr.t.data_location = TensorProto.EXTERNAL
    return value_attr
119+
120+
def get_onnx_attrs(self, external_tensor_storage=None):
    """Return onnx valid attributes.

    The "value" attr is always included and is obtained through
    get_value_attr, so it points to external tensor data stored in
    external_tensor_storage, if included. Every other attr is kept only when
    the op schema declares it, or when no schema is found for the op.
    """
    op_schema = get_schema(self.type, self.graph.opset, self.domain)
    schema_missing = op_schema is None
    if schema_missing and not (self.is_const() or self.is_graph_input()):
        logger.debug("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check",
                     self.name, self.domain, self.type)

    valid_attrs = {}
    for attr_proto in self._attr.values():
        if attr_proto.name == "value":
            valid_attrs[attr_proto.name] = self.get_value_attr(external_tensor_storage)
            continue
        # without a schema there is nothing to check against: keep the attr
        if schema_missing or op_schema.has_attribute(attr_proto.name):
            valid_attrs[attr_proto.name] = attr_proto
    return valid_attrs
103134

@@ -328,7 +359,7 @@ def set_body_graph_as_attr(self, attr_name, graph):
328359
self.graph.contained_graphs[self.name].update({attr_name: graph})
329360
graph.parent_graph = self.graph
330361

331-
def update_proto(self):
362+
def update_proto(self, external_tensor_storage=None):
332363
"""Update protobuf from internal structure."""
333364
nodes = list(self._op.input)
334365
for node in nodes:
@@ -346,10 +377,10 @@ def update_proto(self):
346377
attr_graphs = self.get_body_graphs()
347378
if attr_graphs:
348379
for attr_name, sub_graph in attr_graphs.items():
349-
graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name)
380+
graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name, external_tensor_storage)
350381
self.set_attr(attr_name, graph_proto)
351382

352-
attr = list(self.attr_onnx.values())
383+
attr = list(self.get_onnx_attrs(external_tensor_storage).values())
353384
if attr:
354385
self._op.attribute.extend(attr)
355386

@@ -743,10 +774,10 @@ def update_node_shape_dtype(self, node, override=False):
743774
self.set_shape(output, shape)
744775
logger.debug("Set shape of [%s] to %s", output, shape)
745776

746-
def update_proto(self):
777+
def update_proto(self, external_tensor_storage=None):
747778
"""Update the onnx protobuf from out internal Node structure."""
748779
for node in self._nodes:
749-
node.update_proto()
780+
node.update_proto(external_tensor_storage)
750781

751782
def get_nodes(self):
752783
"""Get node list."""
@@ -963,7 +994,7 @@ def _get_unvisited_child(g, node, not_visited):
963994
ret = [x for _, x in sorted(zip(label, ops))]
964995
self.reset_nodes(ret)
965996

966-
def make_graph(self, doc, graph_name=None):
997+
def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
967998
"""
968999
Create GraphProto for onnx from internal graph.
9691000
Args:
@@ -973,7 +1004,7 @@ def make_graph(self, doc, graph_name=None):
9731004
graph_name = graph_name or self.graph_name
9741005
self.delete_unused_nodes(self.outputs)
9751006
self.topological_sort(self.get_nodes())
976-
self.update_proto()
1007+
self.update_proto(external_tensor_storage)
9771008

9781009
# TODO: we'd want to do something like this so that transpose optimizer is active
9791010
# for all (unit) tests
@@ -1016,7 +1047,7 @@ def make_graph(self, doc, graph_name=None):
10161047
# not to use numpy_helper.from_array to create a new tensor
10171048
# because sometimes onnx will have a bug that only check the tensor data in specific field
10181049
# such as at upsample it only checks the float_data field.
1019-
t = op.get_attr("value")
1050+
t = op.get_value_attr(external_tensor_storage)
10201051
tensor = helper.get_attribute_value(t)
10211052
tensor.name = op.output[0]
10221053
initializers.append(tensor)
@@ -1045,14 +1076,14 @@ def make_graph(self, doc, graph_name=None):
10451076

10461077
return graph
10471078

1048-
def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs):
1079+
def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", external_tensor_storage=None, **kwargs):
10491080
"""
10501081
Create final ModelProto for onnx from internal graph.
10511082
Args:
10521083
optimize: optimize graph via onnx
10531084
doc: text for doc string of the model
10541085
"""
1055-
graph = self.make_graph(graph_doc, graph_name)
1086+
graph = self.make_graph(graph_doc, graph_name, external_tensor_storage)
10561087

10571088
if "producer_name" not in kwargs:
10581089
kwargs = {"producer_name": "tf2onnx",

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def _handle_nhwc_tranpose(self, trans):
282282
return False
283283
# move transpose into branches to let Transposes can be "handled" in each branch
284284
for n in out_nodes:
285-
branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.attr_onnx)
285+
branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.get_onnx_attrs())
286286
n.graph.replace_input(n, trans.output[0], branch_trans.output[0])
287287

288288
self._g.remove_node(trans.name)
@@ -407,7 +407,7 @@ def _add_handler(self, trans, node):
407407
target_node.set_tensor_value(target_val)
408408

409409
conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
410-
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
410+
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.get_onnx_attrs())
411411
ops = self._g.get_nodes()
412412
trans.input[0] = utils.port_name(conv_node.name)
413413
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])

tf2onnx/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def build_onnx_op(node):
136136
copied_sub_graph = copy.deepcopy(sub_graph)
137137
graph_proto = copied_sub_graph.make_graph("graph for " + node.name + " " + attr_name)
138138
attr.append(helper.make_attribute(attr_name, graph_proto))
139-
attr.extend(node.attr_onnx.values())
139+
attr.extend(node.get_onnx_attrs().values())
140140
if attr:
141141
onnx_node.attribute.extend(attr)
142142
return onnx_node

tf2onnx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,9 @@ def is_same(node_1, node_2):
405405
if node_1.type != node_2.type:
406406
return False
407407
# check onnx attributes
408-
if node_1.attr_onnx.keys() != node_2.attr_onnx.keys():
408+
if node_1.get_onnx_attrs().keys() != node_2.get_onnx_attrs().keys():
409409
return False
410-
for name in node_1.attr_onnx.keys(): # pylint: disable=consider-iterating-dictionary
410+
for name in node_1.get_onnx_attrs().keys(): # pylint: disable=consider-iterating-dictionary
411411
if node_1.get_attr_value(name) != node_2.get_attr_value(name):
412412
return False
413413
return True

0 commit comments

Comments
 (0)