Skip to content

Commit 45c2ec7

Browse files
Created compress_graph_def function
1 parent 0110037 commit 45c2ec7

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

tf2onnx/tf_utils.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,24 @@ def get_tf_node_attr(node, name):
128128
def get_tf_version():
129129
return LooseVersion(tf.__version__)
130130

131-
132-
def tflist_to_onnx(g, shape_override):
131+
def compress_graph_def(graph_def):
132+
"""
133+
Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.
134+
"""
135+
node_defs = list(graph_def.node)
136+
const_node_values = {}
137+
for node_def in node_defs:
138+
if node_def.op == 'Const':
139+
tensor = node_def.attr["value"].tensor
140+
# Small constants are sometimes used to store shape information and must be maintained
141+
if len(tensor.tensor_content) > 1000:
142+
make_sure(node_def.name not in const_node_values, "Two nodes in graph have same name %s", node_def.name)
143+
const_node_values[node_def.name] = tensor.tensor_content
144+
tensor.tensor_content = b''
145+
return const_node_values
146+
147+
148+
def tflist_to_onnx(g, shape_override, const_node_values=None):
133149
"""
134150
Convert the tf-node list into an onnx graph with minimal rewrites so
135151
we can use the onnx graph as intermediate graph.
@@ -198,7 +214,10 @@ def tflist_to_onnx(g, shape_override):
198214
attr[a] = nattr.name
199215
functions[nattr.name] = input_shapes
200216
elif a == "value":
201-
onnx_tensor = tf_to_onnx_tensor(get_tf_node_attr(node, a), name=port_name(node.name))
217+
tensor = get_tf_node_attr(node, a)
218+
if const_node_values and node.name in const_node_values:
219+
tensor.tensor_content = const_node_values[node.name]
220+
onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
202221
attr[a] = onnx_tensor
203222
elif a == "DstT":
204223
attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
@@ -222,8 +241,8 @@ def tflist_to_onnx(g, shape_override):
222241
return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions
223242

224243

225-
def tensorflow_to_onnx(graph, shape_override):
244+
def tensorflow_to_onnx(graph, shape_override, const_node_values=None):
226245
"""
227246
Load tensorflow graph and do a conversion.
228247
"""
229-
return tflist_to_onnx(graph, shape_override)
248+
return tflist_to_onnx(graph, shape_override, const_node_values)

0 commit comments

Comments
 (0)