@@ -124,8 +124,24 @@ def get_tf_node_attr(node, name):
124
124
def get_tf_version ():
125
125
return LooseVersion (tf .__version__ )
126
126
127
-
128
- def tflist_to_onnx (g , shape_override ):
127
+ def compress_graph_def (graph_def ):
128
+ """
129
+ Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.
130
+ """
131
+ node_defs = list (graph_def .node )
132
+ const_node_values = {}
133
+ for node_def in node_defs :
134
+ if node_def .op == 'Const' :
135
+ tensor = node_def .attr ["value" ].tensor
136
+ # Small constants are sometimes used to store shape information and must be maintained
137
+ if len (tensor .tensor_content ) > 1000 :
138
+ make_sure (node_def .name not in const_node_values , "Two nodes in graph have same name %s" , node_def .name )
139
+ const_node_values [node_def .name ] = tensor .tensor_content
140
+ tensor .tensor_content = b''
141
+ return const_node_values
142
+
143
+
144
+ def tflist_to_onnx (g , shape_override , const_node_values = None ):
129
145
"""
130
146
Convert the tf-node list into an onnx graph with minimal rewrites so
131
147
we can use the onnx graph as intermediate graph.
@@ -193,7 +209,10 @@ def tflist_to_onnx(g, shape_override):
193
209
attr [a ] = nattr .name
194
210
functions [nattr .name ] = input_shapes
195
211
elif a == "value" :
196
- onnx_tensor = tf_to_onnx_tensor (get_tf_node_attr (node , a ), name = port_name (node .name ))
212
+ tensor = get_tf_node_attr (node , a )
213
+ if const_node_values and node .name in const_node_values :
214
+ tensor .tensor_content = const_node_values [node .name ]
215
+ onnx_tensor = tf_to_onnx_tensor (tensor , name = port_name (node .name ))
197
216
attr [a ] = onnx_tensor
198
217
elif a == "DstT" :
199
218
attr ["to" ] = map_tf_dtype (get_tf_node_attr (node , "DstT" ))
@@ -217,8 +236,8 @@ def tflist_to_onnx(g, shape_override):
217
236
return onnx_nodes , op_cnt , attr_cnt , output_shapes , dtypes , functions
218
237
219
238
220
- def tensorflow_to_onnx (graph , shape_override ):
239
+ def tensorflow_to_onnx (graph , shape_override , const_node_values = None ):
221
240
"""
222
241
Load tensorflow graph and do a conversion.
223
242
"""
224
- return tflist_to_onnx (graph , shape_override )
243
+ return tflist_to_onnx (graph , shape_override , const_node_values )
0 commit comments