28
28
# todo(pengwa): remove protected-access later
29
29
# pylint: disable=broad-except,protected-access
30
30
31
+ class ExternalTensorStorage ():
32
+ """Passed into graph and node methods to accumulate tensors to save externally"""
33
+ def __init__ (self ):
34
+ self .name_to_tensor_data = {}
35
+ self .name_counter = 0
36
+ self .external_tensor_size_threshold = 1024
37
+ self .node_to_modified_value_attr = {}
31
38
32
39
class Node (object ):
33
40
"""A Node - wrapper around onnx nodes that we use for graph manipulations."""
@@ -93,16 +100,40 @@ def inputs(self):
93
100
def attr (self ):
94
101
return self ._attr
95
102
96
- @property
97
- def attr_onnx (self ):
98
- """Return onnx valid attributes"""
103
+ def get_value_attr (self , external_tensor_storage = None ):
104
+ """Return onnx attr for value property of node.
105
+ Attr is modified to point to external tensor data stored in external_tensor_storage, if included.
106
+ """
107
+ a = self ._attr ["value" ]
108
+ if external_tensor_storage is not None and self in external_tensor_storage .node_to_modified_value_attr :
109
+ return external_tensor_storage .node_to_modified_value_attr [self ]
110
+ if external_tensor_storage is None or a .type != AttributeProto .TENSOR :
111
+ return a
112
+ if np .product (a .t .dims ) > external_tensor_storage .external_tensor_size_threshold :
113
+ a = copy .copy (a )
114
+ tensor_name = self .name + "_" + str (external_tensor_storage .name_counter )
115
+ external_tensor_storage .name_counter += 1
116
+ external_tensor_storage .name_to_tensor_data [tensor_name ] = a .t .raw_data
117
+ external_tensor_storage .node_to_modified_value_attr [self ] = a
118
+ a .t .raw_data = b'__EXTERNAL'
119
+ location = a .t .external_data .add ()
120
+ location .key = "location"
121
+ location .value = tensor_name
122
+ a .t .data_location = TensorProto .EXTERNAL
123
+ return a
124
+
125
+ def get_onnx_attrs (self , external_tensor_storage = None ):
126
+ """Return onnx valid attributes.
127
+ Attrs point to external tensor data stored in external_tensor_storage, if included."""
99
128
schema = get_schema (self .type , self .graph .opset , self .domain )
100
129
if schema is None and not (self .is_const () or self .is_graph_input ()):
101
130
logger .debug ("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check" ,
102
131
self .name , self .domain , self .type )
103
132
onnx_attrs = {}
104
133
for a in self ._attr .values ():
105
- if schema is None or schema .has_attribute (a .name ):
134
+ if a .name == "value" :
135
+ onnx_attrs [a .name ] = self .get_value_attr (external_tensor_storage )
136
+ elif schema is None or schema .has_attribute (a .name ):
106
137
onnx_attrs [a .name ] = a
107
138
return onnx_attrs
108
139
@@ -333,7 +364,7 @@ def set_body_graph_as_attr(self, attr_name, graph):
333
364
self .graph .contained_graphs [self .name ].update ({attr_name : graph })
334
365
graph .parent_graph = self .graph
335
366
336
- def update_proto (self ):
367
+ def update_proto (self , external_tensor_storage = None ):
337
368
"""Update protobuf from internal structure."""
338
369
nodes = list (self ._op .input )
339
370
for node in nodes :
@@ -351,10 +382,10 @@ def update_proto(self):
351
382
attr_graphs = self .get_body_graphs ()
352
383
if attr_graphs :
353
384
for attr_name , sub_graph in attr_graphs .items ():
354
- graph_proto = sub_graph .make_graph ("graph for " + self .name + " " + attr_name )
385
+ graph_proto = sub_graph .make_graph ("graph for " + self .name + " " + attr_name , external_tensor_storage )
355
386
self .set_attr (attr_name , graph_proto )
356
387
357
- attr = list (self .attr_onnx .values ())
388
+ attr = list (self .get_onnx_attrs ( external_tensor_storage ) .values ())
358
389
if attr :
359
390
self ._op .attribute .extend (attr )
360
391
@@ -748,10 +779,10 @@ def update_node_shape_dtype(self, node, override=False):
748
779
self .set_shape (output , shape )
749
780
logger .debug ("Set shape of [%s] to %s" , output , shape )
750
781
751
- def update_proto (self ):
782
+ def update_proto (self , external_tensor_storage = None ):
752
783
"""Update the onnx protobuf from out internal Node structure."""
753
784
for node in self ._nodes :
754
- node .update_proto ()
785
+ node .update_proto (external_tensor_storage )
755
786
756
787
def get_nodes (self ):
757
788
"""Get node list."""
@@ -968,7 +999,7 @@ def _get_unvisited_child(g, node, not_visited):
968
999
ret = [x for _ , x in sorted (zip (label , ops ))]
969
1000
self .reset_nodes (ret )
970
1001
971
- def make_graph (self , doc , graph_name = None ):
1002
+ def make_graph (self , doc , graph_name = None , external_tensor_storage = None ):
972
1003
"""
973
1004
Create GraphProto for onnx from internal graph.
974
1005
Args:
@@ -978,7 +1009,7 @@ def make_graph(self, doc, graph_name=None):
978
1009
graph_name = graph_name or self .graph_name
979
1010
self .delete_unused_nodes (self .outputs )
980
1011
self .topological_sort (self .get_nodes ())
981
- self .update_proto ()
1012
+ self .update_proto (external_tensor_storage )
982
1013
983
1014
# TODO: we'd want to do something like this so that transpose optimizer is active
984
1015
# for all (unit) tests
@@ -1021,7 +1052,7 @@ def make_graph(self, doc, graph_name=None):
1021
1052
# not to use numpy_helper.from_array to create a new tensor
1022
1053
# because sometimes onnx will have a bug that only check the tensor data in specific field
1023
1054
# such as at upsample it only checks the float_data field.
1024
- t = op .get_attr ( "value" )
1055
+ t = op .get_value_attr ( external_tensor_storage )
1025
1056
tensor = helper .get_attribute_value (t )
1026
1057
tensor .name = op .output [0 ]
1027
1058
initializers .append (tensor )
@@ -1050,14 +1081,14 @@ def make_graph(self, doc, graph_name=None):
1050
1081
1051
1082
return graph
1052
1083
1053
- def make_model (self , graph_doc , optimize = False , graph_name = "tf2onnx" , ** kwargs ):
1084
+ def make_model (self , graph_doc , optimize = False , graph_name = "tf2onnx" , external_tensor_storage = None , ** kwargs ):
1054
1085
"""
1055
1086
Create final ModelProto for onnx from internal graph.
1056
1087
Args:
1057
1088
optimize: optimize graph via onnx
1058
1089
doc: text for doc string of the model
1059
1090
"""
1060
- graph = self .make_graph (graph_doc , graph_name )
1091
+ graph = self .make_graph (graph_doc , graph_name , external_tensor_storage )
1061
1092
1062
1093
if "producer_name" not in kwargs :
1063
1094
kwargs = {"producer_name" : "tf2onnx" ,
0 commit comments