Skip to content

Commit f99785a

Browse files
Merge pull request #1087 from onnx/tom/ConvertLargeModels
Created compress_graph_def function
2 parents 240fae1 + 45c2ec7 commit f99785a

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

tf2onnx/tf_utils.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,24 @@ def get_tf_node_attr(node, name):
124124
def get_tf_version():
125125
return LooseVersion(tf.__version__)
126126

127-
128-
def tflist_to_onnx(g, shape_override):
127+
def compress_graph_def(graph_def):
128+
"""
129+
Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.
130+
"""
131+
node_defs = list(graph_def.node)
132+
const_node_values = {}
133+
for node_def in node_defs:
134+
if node_def.op == 'Const':
135+
tensor = node_def.attr["value"].tensor
136+
# Small constants are sometimes used to store shape information and must be maintained
137+
if len(tensor.tensor_content) > 1000:
138+
make_sure(node_def.name not in const_node_values, "Two nodes in graph have same name %s", node_def.name)
139+
const_node_values[node_def.name] = tensor.tensor_content
140+
tensor.tensor_content = b''
141+
return const_node_values
142+
143+
144+
def tflist_to_onnx(g, shape_override, const_node_values=None):
129145
"""
130146
Convert the tf-node list into an onnx graph with minimal rewrites so
131147
we can use the onnx graph as intermediate graph.
@@ -193,7 +209,10 @@ def tflist_to_onnx(g, shape_override):
193209
attr[a] = nattr.name
194210
functions[nattr.name] = input_shapes
195211
elif a == "value":
196-
onnx_tensor = tf_to_onnx_tensor(get_tf_node_attr(node, a), name=port_name(node.name))
212+
tensor = get_tf_node_attr(node, a)
213+
if const_node_values and node.name in const_node_values:
214+
tensor.tensor_content = const_node_values[node.name]
215+
onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
197216
attr[a] = onnx_tensor
198217
elif a == "DstT":
199218
attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
@@ -217,8 +236,8 @@ def tflist_to_onnx(g, shape_override):
217236
return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions
218237

219238

220-
def tensorflow_to_onnx(graph, shape_override):
239+
def tensorflow_to_onnx(graph, shape_override, const_node_values=None):
221240
"""
222241
Load tensorflow graph and do a conversion.
223242
"""
224-
return tflist_to_onnx(graph, shape_override)
243+
return tflist_to_onnx(graph, shape_override, const_node_values)

0 commit comments

Comments
 (0)