Skip to content

Commit d6385a1

Browse files
Fix errors in tf freezing with AssignVariableOp (#1481)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 82a4f69 commit d6385a1

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

tf2onnx/tf_loader.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,15 +138,39 @@ def convert_variables_to_constants_large_model(func):
138138
return frozen_graph_def
139139

140140

141+
def fix_freezing_errors(graph_def):
142+
assign_var_ops = []
143+
for i in reversed(range(len(graph_def.node))):
144+
if graph_def.node[i].op == "AssignVariableOp":
145+
assign_var_ops.append(graph_def.node.pop(i).name)
146+
logger.warning("Removed AssignVariableOp %s", assign_var_ops[-1])
147+
names_to_remove = set(assign_var_ops)
148+
for n in graph_def.node:
149+
for i in reversed(range(len(n.input))):
150+
if n.input[i].startswith("^") and n.input[i][1:] in names_to_remove:
151+
n.input.pop(i)
152+
return graph_def
153+
154+
141155
def from_function(func, input_names, output_names, large_model=False):
142156
if large_model:
143157
return convert_variables_to_constants_large_model(func)
144158

145-
if get_tf_version() < LooseVersion("2.2"):
146-
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
159+
try:
160+
if get_tf_version() < LooseVersion("2.2"):
161+
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
162+
else:
163+
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False, aggressive_inlining=True)
164+
except ValueError as e:
165+
if "incompatible with expected resource" in str(e):
166+
frozen_func = convert_variables_to_constants_large_model(func)
167+
logger.warning("TF freezing failed. Attempting to fix freezing errors.")
168+
graph_def = fix_freezing_errors(frozen_func)
169+
else:
170+
raise e
147171
else:
148-
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False, aggressive_inlining=True)
149-
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
172+
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
173+
150174
# output_names = [i.name for i in frozen_func.outputs]
151175
with tf.Graph().as_default() as tf_graph:
152176
with tf_session(graph=tf_graph) as sess:

0 commit comments

Comments
 (0)