@@ -128,8 +128,24 @@ def get_tf_node_attr(node, name):
128
128
def get_tf_version ():
129
129
return LooseVersion (tf .__version__ )
130
130
131
-
132
- def tflist_to_onnx (g , shape_override ):
131
+ def compress_graph_def (graph_def ):
132
+ """
133
+ Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.
134
+ """
135
+ node_defs = list (graph_def .node )
136
+ const_node_values = {}
137
+ for node_def in node_defs :
138
+ if node_def .op == 'Const' :
139
+ tensor = node_def .attr ["value" ].tensor
140
+ # Small constants are sometimes used to store shape information and must be maintained
141
+ if len (tensor .tensor_content ) > 1000 :
142
+ make_sure (node_def .name not in const_node_values , "Two nodes in graph have same name %s" , node_def .name )
143
+ const_node_values [node_def .name ] = tensor .tensor_content
144
+ tensor .tensor_content = b''
145
+ return const_node_values
146
+
147
+
148
+ def tflist_to_onnx (g , shape_override , const_node_values = None ):
133
149
"""
134
150
Convert the tf-node list into an onnx graph with minimal rewrites so
135
151
we can use the onnx graph as intermediate graph.
@@ -198,7 +214,10 @@ def tflist_to_onnx(g, shape_override):
198
214
attr [a ] = nattr .name
199
215
functions [nattr .name ] = input_shapes
200
216
elif a == "value" :
201
- onnx_tensor = tf_to_onnx_tensor (get_tf_node_attr (node , a ), name = port_name (node .name ))
217
+ tensor = get_tf_node_attr (node , a )
218
+ if const_node_values and node .name in const_node_values :
219
+ tensor .tensor_content = const_node_values [node .name ]
220
+ onnx_tensor = tf_to_onnx_tensor (tensor , name = port_name (node .name ))
202
221
attr [a ] = onnx_tensor
203
222
elif a == "DstT" :
204
223
attr ["to" ] = map_tf_dtype (get_tf_node_attr (node , "DstT" ))
@@ -222,8 +241,8 @@ def tflist_to_onnx(g, shape_override):
222
241
return onnx_nodes , op_cnt , attr_cnt , output_shapes , dtypes , functions
223
242
224
243
225
- def tensorflow_to_onnx (graph , shape_override ):
244
+ def tensorflow_to_onnx (graph , shape_override , const_node_values = None ):
226
245
"""
227
246
Load tensorflow graph and do a conversion.
228
247
"""
229
- return tflist_to_onnx (graph , shape_override )
248
+ return tflist_to_onnx (graph , shape_override , const_node_values )
0 commit comments