28
28
# todo(pengwa): remove protected-access later
29
29
# pylint: disable=broad-except,protected-access
30
30
31
class ExternalTensorStorage(object):
    """Passed into graph and node methods to accumulate tensors to save externally.

    A single instance is shared while a graph is converted to protobuf; node
    methods record in it which tensors were moved out of the model and the
    rewritten attributes that now point at the external data.
    """

    def __init__(self, external_tensor_size_threshold=1024):
        """Create an empty storage.

        Args:
            external_tensor_size_threshold: tensors whose element count
                (product of their dims) is greater than this value are stored
                externally. Defaults to 1024, the previously hard-coded value.
        """
        # Generated external tensor name -> raw bytes of the tensor.
        self.name_to_tensor_data = {}
        # Monotonic counter appended to node names to build tensor names.
        self.name_counter = 0
        self.external_tensor_size_threshold = external_tensor_size_threshold
        # Node -> copy of its "value" attribute rewritten to reference external data.
        self.node_to_modified_value_attr = {}
31
38
32
39
class Node (object ):
33
40
"""A Node - wrapper around onnx nodes that we use for graph manipulations."""
@@ -88,16 +95,40 @@ def inputs(self):
88
95
def attr (self ):
89
96
return self ._attr
90
97
91
- @property
92
- def attr_onnx (self ):
93
- """Return onnx valid attributes"""
98
def get_value_attr(self, external_tensor_storage=None):
    """Return onnx attr for value property of node.

    Attr is modified to point to external tensor data stored in
    external_tensor_storage, if included.

    Args:
        external_tensor_storage: optional storage object that accumulates
            tensors to be saved outside of the model. When None the node's
            own "value" attribute is returned untouched.

    Returns:
        The node's "value" attribute, or a copy of it whose tensor references
        external data when the tensor has more elements than
        external_tensor_storage.external_tensor_size_threshold.

    Raises:
        KeyError: if the node has no "value" attribute.
    """
    a = self._attr["value"]
    # Reuse the attribute already rewritten for this node so the tensor is
    # registered in the storage only once.
    if external_tensor_storage is not None and self in external_tensor_storage.node_to_modified_value_attr:
        return external_tensor_storage.node_to_modified_value_attr[self]
    if external_tensor_storage is None or a.type != AttributeProto.TENSOR:
        return a
    # np.prod replaces the deprecated np.product alias (removed in NumPy 2.0).
    if np.prod(a.t.dims) > external_tensor_storage.external_tensor_size_threshold:
        # Work on a copy so the attribute kept on the node retains its data.
        a = copy.copy(a)
        tensor_name = self.name + "_" + str(external_tensor_storage.name_counter)
        external_tensor_storage.name_counter += 1
        # NOTE(review): only raw_data is moved out; a tensor holding its data
        # in a typed field (e.g. float_data) would be registered with empty
        # bytes here - confirm all large constants use raw_data.
        external_tensor_storage.name_to_tensor_data[tensor_name] = a.t.raw_data
        external_tensor_storage.node_to_modified_value_attr[self] = a
        # Placeholder bytes; the real data is stored under tensor_name.
        a.t.raw_data = b'__EXTERNAL'
        location = a.t.external_data.add()
        location.key = "location"
        location.value = tensor_name
        a.t.data_location = TensorProto.EXTERNAL
    return a
119
+
120
def get_onnx_attrs(self, external_tensor_storage=None):
    """Collect the attributes of this node that are valid for onnx.

    The "value" attribute is always included and is obtained through
    get_value_attr, so it points to external tensor data stored in
    external_tensor_storage when one is given. Any other attribute is kept
    if no schema is found for the op or if the schema declares it.

    Returns:
        dict mapping attribute name to attribute.
    """
    op_schema = get_schema(self.type, self.graph.opset, self.domain)
    schema_missing = op_schema is None
    if schema_missing and not (self.is_const() or self.is_graph_input()):
        logger.debug("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check",
                     self.name, self.domain, self.type)

    result = {}
    for attribute in self._attr.values():
        attr_name = attribute.name
        if attr_name == "value":
            result[attr_name] = self.get_value_attr(external_tensor_storage)
            continue
        if schema_missing or op_schema.has_attribute(attr_name):
            result[attr_name] = attribute
    return result
103
134
@@ -328,7 +359,7 @@ def set_body_graph_as_attr(self, attr_name, graph):
328
359
self .graph .contained_graphs [self .name ].update ({attr_name : graph })
329
360
graph .parent_graph = self .graph
330
361
331
- def update_proto (self ):
362
+ def update_proto (self , external_tensor_storage = None ):
332
363
"""Update protobuf from internal structure."""
333
364
nodes = list (self ._op .input )
334
365
for node in nodes :
@@ -346,10 +377,10 @@ def update_proto(self):
346
377
attr_graphs = self .get_body_graphs ()
347
378
if attr_graphs :
348
379
for attr_name , sub_graph in attr_graphs .items ():
349
- graph_proto = sub_graph .make_graph ("graph for " + self .name + " " + attr_name )
380
+ graph_proto = sub_graph .make_graph ("graph for " + self .name + " " + attr_name , external_tensor_storage )
350
381
self .set_attr (attr_name , graph_proto )
351
382
352
- attr = list (self .attr_onnx .values ())
383
+ attr = list (self .get_onnx_attrs ( external_tensor_storage ) .values ())
353
384
if attr :
354
385
self ._op .attribute .extend (attr )
355
386
@@ -743,10 +774,10 @@ def update_node_shape_dtype(self, node, override=False):
743
774
self .set_shape (output , shape )
744
775
logger .debug ("Set shape of [%s] to %s" , output , shape )
745
776
746
def update_proto(self, external_tensor_storage=None):
    """Update the onnx protobuf from our internal Node structure.

    Every node of the graph refreshes its own protobuf;
    external_tensor_storage is forwarded unchanged to each node.
    """
    for current in self._nodes:
        current.update_proto(external_tensor_storage)
750
781
751
782
def get_nodes (self ):
752
783
"""Get node list."""
@@ -963,7 +994,7 @@ def _get_unvisited_child(g, node, not_visited):
963
994
ret = [x for _ , x in sorted (zip (label , ops ))]
964
995
self .reset_nodes (ret )
965
996
966
- def make_graph (self , doc , graph_name = None ):
997
+ def make_graph (self , doc , graph_name = None , external_tensor_storage = None ):
967
998
"""
968
999
Create GraphProto for onnx from internal graph.
969
1000
Args:
@@ -973,7 +1004,7 @@ def make_graph(self, doc, graph_name=None):
973
1004
graph_name = graph_name or self .graph_name
974
1005
self .delete_unused_nodes (self .outputs )
975
1006
self .topological_sort (self .get_nodes ())
976
- self .update_proto ()
1007
+ self .update_proto (external_tensor_storage )
977
1008
978
1009
# TODO: we'd want to do something like this so that transpose optimizer is active
979
1010
# for all (unit) tests
@@ -1016,7 +1047,7 @@ def make_graph(self, doc, graph_name=None):
1016
1047
# not to use numpy_helper.from_array to create a new tensor
1017
1048
# because sometimes onnx will have a bug that only check the tensor data in specific field
1018
1049
# such as at upsample it only checks the float_data field.
1019
- t = op .get_attr ( "value" )
1050
+ t = op .get_value_attr ( external_tensor_storage )
1020
1051
tensor = helper .get_attribute_value (t )
1021
1052
tensor .name = op .output [0 ]
1022
1053
initializers .append (tensor )
@@ -1045,14 +1076,14 @@ def make_graph(self, doc, graph_name=None):
1045
1076
1046
1077
return graph
1047
1078
1048
- def make_model (self , graph_doc , optimize = False , graph_name = "tf2onnx" , ** kwargs ):
1079
+ def make_model (self , graph_doc , optimize = False , graph_name = "tf2onnx" , external_tensor_storage = None , ** kwargs ):
1049
1080
"""
1050
1081
Create final ModelProto for onnx from internal graph.
1051
1082
Args:
1052
1083
optimize: optimize graph via onnx
1053
1084
doc: text for doc string of the model
1054
1085
"""
1055
- graph = self .make_graph (graph_doc , graph_name )
1086
+ graph = self .make_graph (graph_doc , graph_name , external_tensor_storage )
1056
1087
1057
1088
if "producer_name" not in kwargs :
1058
1089
kwargs = {"producer_name" : "tf2onnx" ,
0 commit comments