@@ -138,15 +138,39 @@ def convert_variables_to_constants_large_model(func):
138
138
return frozen_graph_def
139
139
140
140
141
+ def fix_freezing_errors (graph_def ):
142
+ assign_var_ops = []
143
+ for i in reversed (range (len (graph_def .node ))):
144
+ if graph_def .node [i ].op == "AssignVariableOp" :
145
+ assign_var_ops .append (graph_def .node .pop (i ).name )
146
+ logger .warning ("Removed AssignVariableOp %s" , assign_var_ops [- 1 ])
147
+ names_to_remove = set (assign_var_ops )
148
+ for n in graph_def .node :
149
+ for i in reversed (range (len (n .input ))):
150
+ if n .input [i ].startswith ("^" ) and n .input [i ][1 :] in names_to_remove :
151
+ n .input .pop (i )
152
+ return graph_def
153
+
154
+
141
155
def from_function (func , input_names , output_names , large_model = False ):
142
156
if large_model :
143
157
return convert_variables_to_constants_large_model (func )
144
158
145
- if get_tf_version () < LooseVersion ("2.2" ):
146
- frozen_func = convert_variables_to_constants_v2 (func , lower_control_flow = False )
159
+ try :
160
+ if get_tf_version () < LooseVersion ("2.2" ):
161
+ frozen_func = convert_variables_to_constants_v2 (func , lower_control_flow = False )
162
+ else :
163
+ frozen_func = convert_variables_to_constants_v2 (func , lower_control_flow = False , aggressive_inlining = True )
164
+ except ValueError as e :
165
+ if "incompatible with expected resource" in str (e ):
166
+ frozen_func = convert_variables_to_constants_large_model (func )
167
+ logger .warning ("TF freezing failed. Attempting to fix freezing errors." )
168
+ graph_def = fix_freezing_errors (frozen_func )
169
+ else :
170
+ raise e
147
171
else :
148
- frozen_func = convert_variables_to_constants_v2 ( func , lower_control_flow = False , aggressive_inlining = True )
149
- graph_def = frozen_func . graph . as_graph_def ( add_shapes = True )
172
+ graph_def = frozen_func . graph . as_graph_def ( add_shapes = True )
173
+
150
174
# output_names = [i.name for i in frozen_func.outputs]
151
175
with tf .Graph ().as_default () as tf_graph :
152
176
with tf_session (graph = tf_graph ) as sess :
0 commit comments