28
28
# todo(pengwa): remove protected-access later
29
29
# pylint: disable=broad-except,protected-access
30
30
31
+ class ExternalTensorStorage ():
32
+ """Passed into graph and node methods to accumulate tensors to save externally"""
33
+ def __init__ (self ):
34
+ self .name_to_tensor_data = {}
35
+ self .name_counter = 0
36
+ self .external_tensor_size_threshold = 1024
37
+ self .node_to_modified_value_attr = {}
31
38
32
39
class Node (object ):
33
40
"""A Node - wrapper around onnx nodes that we use for graph manipulations."""
@@ -93,16 +100,40 @@ def inputs(self):
93
100
def attr (self ):
94
101
return self ._attr
95
102
96
- @property
97
- def attr_onnx (self ):
98
- """Return onnx valid attributes"""
103
+ def get_value_attr (self , external_tensor_storage = None ):
104
+ """Return onnx attr for value property of node.
105
+ Attr is modified to point to external tensor data stored in external_tensor_storage, if included.
106
+ """
107
+ a = self ._attr ["value" ]
108
+ if external_tensor_storage is not None and self in external_tensor_storage .node_to_modified_value_attr :
109
+ return external_tensor_storage .node_to_modified_value_attr [self ]
110
+ if external_tensor_storage is None or a .type != AttributeProto .TENSOR :
111
+ return a
112
+ if np .product (a .t .dims ) > external_tensor_storage .external_tensor_size_threshold :
113
+ a = copy .copy (a )
114
+ tensor_name = self .name + "_" + str (external_tensor_storage .name_counter )
115
+ external_tensor_storage .name_counter += 1
116
+ external_tensor_storage .name_to_tensor_data [tensor_name ] = a .t .raw_data
117
+ external_tensor_storage .node_to_modified_value_attr [self ] = a
118
+ a .t .raw_data = b'__EXTERNAL'
119
+ location = a .t .external_data .add ()
120
+ location .key = "location"
121
+ location .value = tensor_name
122
+ a .t .data_location = TensorProto .EXTERNAL
123
+ return a
124
+
125
+ def get_onnx_attrs (self , external_tensor_storage = None ):
126
+ """Return onnx valid attributes.
127
+ Attrs point to external tensor data stored in external_tensor_storage, if included."""
99
128
schema = get_schema (self .type , self .graph .opset , self .domain )
100
129
if schema is None and not (self .is_const () or self .is_graph_input ()):
101
130
logger .debug ("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check" ,
102
131
self .name , self .domain , self .type )
103
132
onnx_attrs = {}
104
133
for a in self ._attr .values ():
105
- if schema is None or schema .has_attribute (a .name ):
134
+ if a .name == "value" :
135
+ onnx_attrs [a .name ] = self .get_value_attr (external_tensor_storage )
136
+ elif schema is None or schema .has_attribute (a .name ):
106
137
onnx_attrs [a .name ] = a
107
138
return onnx_attrs
108
139
@@ -333,7 +364,7 @@ def set_body_graph_as_attr(self, attr_name, graph):
333
364
self .graph .contained_graphs [self .name ].update ({attr_name : graph })
334
365
graph .parent_graph = self .graph
335
366
336
- def update_proto (self ):
367
+ def update_proto (self , external_tensor_storage = None ):
337
368
"""Update protobuf from internal structure."""
338
369
nodes = list (self ._op .input )
339
370
for node in nodes :
@@ -351,10 +382,10 @@ def update_proto(self):
351
382
attr_graphs = self .get_body_graphs ()
352
383
if attr_graphs :
353
384
for attr_name , sub_graph in attr_graphs .items ():
354
- graph_proto = sub_graph .make_graph ("graph for " + self .name + " " + attr_name )
385
+ graph_proto = sub_graph .make_graph ("graph for " + self .name + " " + attr_name , external_tensor_storage )
355
386
self .set_attr (attr_name , graph_proto )
356
387
357
- attr = list (self .attr_onnx .values ())
388
+ attr = list (self .get_onnx_attrs ( external_tensor_storage ) .values ())
358
389
if attr :
359
390
self ._op .attribute .extend (attr )
360
391
@@ -766,10 +797,10 @@ def update_node_shape_dtype(self, node, override=False):
766
797
self .set_shape (output , shape )
767
798
logger .debug ("Set shape of [%s] to %s" , output , shape )
768
799
769
- def update_proto (self ):
800
+ def update_proto (self , external_tensor_storage = None ):
770
801
"""Update the onnx protobuf from out internal Node structure."""
771
802
for node in self ._nodes :
772
- node .update_proto ()
803
+ node .update_proto (external_tensor_storage )
773
804
774
805
def get_nodes (self ):
775
806
"""Get node list."""
@@ -988,7 +1019,7 @@ def _get_unvisited_child(g, node, not_visited):
988
1019
ret = [x for _ , x in sorted (zip (label , ops ))]
989
1020
self .reset_nodes (ret )
990
1021
991
- def make_graph (self , doc , graph_name = None ):
1022
+ def make_graph (self , doc , graph_name = None , external_tensor_storage = None ):
992
1023
"""
993
1024
Create GraphProto for onnx from internal graph.
994
1025
Args:
@@ -998,7 +1029,7 @@ def make_graph(self, doc, graph_name=None):
998
1029
graph_name = graph_name or self .graph_name
999
1030
self .delete_unused_nodes (self .outputs )
1000
1031
self .topological_sort (self .get_nodes ())
1001
- self .update_proto ()
1032
+ self .update_proto (external_tensor_storage )
1002
1033
1003
1034
# TODO: we'd want to do something like this so that transpose optimizer is active
1004
1035
# for all (unit) tests
@@ -1041,7 +1072,7 @@ def make_graph(self, doc, graph_name=None):
1041
1072
# not to use numpy_helper.from_array to create a new tensor
1042
1073
# because sometimes onnx will have a bug that only check the tensor data in specific field
1043
1074
# such as at upsample it only checks the float_data field.
1044
- t = op .get_attr ( "value" )
1075
+ t = op .get_value_attr ( external_tensor_storage )
1045
1076
tensor = helper .get_attribute_value (t )
1046
1077
tensor .name = op .output [0 ]
1047
1078
initializers .append (tensor )
@@ -1070,14 +1101,14 @@ def make_graph(self, doc, graph_name=None):
1070
1101
1071
1102
return graph
1072
1103
1073
- def make_model (self , graph_doc , optimize = False , graph_name = "tf2onnx" , ** kwargs ):
1104
+ def make_model (self , graph_doc , optimize = False , graph_name = "tf2onnx" , external_tensor_storage = None , ** kwargs ):
1074
1105
"""
1075
1106
Create final ModelProto for onnx from internal graph.
1076
1107
Args:
1077
1108
optimize: optimize graph via onnx
1078
1109
doc: text for doc string of the model
1079
1110
"""
1080
- graph = self .make_graph (graph_doc , graph_name )
1111
+ graph = self .make_graph (graph_doc , graph_name , external_tensor_storage )
1081
1112
1082
1113
if "producer_name" not in kwargs :
1083
1114
kwargs = {"producer_name" : "tf2onnx" ,
0 commit comments