@@ -144,7 +144,7 @@ def compute_const_folding_using_tf(g, const_node_values):
144
144
"""Find nodes with constant inputs and compute their values using TF"""
145
145
if const_node_values is None :
146
146
const_node_values = {}
147
- from tf2onnx .tf_loader import tf_session , tf_reset_default_graph
147
+ from tf2onnx .tf_loader import tf_session , tf_placeholder # pylint: disable=import-outside-toplevel
148
148
149
149
ops = g .get_operations ()
150
150
outputs_to_values = {}
@@ -167,28 +167,34 @@ def compute_const_folding_using_tf(g, const_node_values):
167
167
# Find ops with constant inputs and compute their values
168
168
input_names = [i .name for i in node .inputs ]
169
169
output_names = [i .name for i in node .outputs ]
170
- can_fold = len (input_names ) > 0 and all (inp in outputs_to_values for inp in input_names )
170
+ can_fold = node .type not in ['Enter' ]
171
+ can_fold = can_fold and len (input_names ) > 0 and all (inp in outputs_to_values for inp in input_names )
171
172
# We can only fold nodes with a single output
172
173
can_fold = can_fold and len (output_names ) == 1 and output_names [0 ] not in outputs_to_values
173
174
# Skip if value already computed, used, and discarded
174
175
can_fold = can_fold and output_names [0 ] not in unneeded_outputs
175
176
if can_fold :
176
- g = tf .Graph ()
177
- with g .as_default ():
177
+ # Make a mini graph containing just the node to fold
178
+ g2 = tf .Graph ()
179
+ with g2 .as_default ():
178
180
for inp in input_names :
179
- tf .compat .v1 .placeholder (outputs_to_dtypes [inp ], name = inp .split (':' )[0 ])
180
- mini_graph_def = g .as_graph_def ()
181
- mini_graph_def .node .append (node .node_def )
182
- tf_reset_default_graph ()
183
- feed_dict = {}
184
- for inp in input_names :
185
- feed_dict [inp ] = outputs_to_values [inp ]
186
- with tf_session () as sess :
187
- tf .import_graph_def (mini_graph_def , name = '' )
188
- results = sess .run (output_names , feed_dict = feed_dict )
189
- outputs_to_values [output_names [0 ]] = results [0 ]
190
- outputs_to_dtypes [output_names [0 ]] = node .outputs [0 ].dtype
191
- progress = True
181
+ tf_placeholder (outputs_to_dtypes [inp ], name = inp .split (':' )[0 ])
182
+ mini_graph_def = g2 .as_graph_def ()
183
+ mini_graph_def .node .append (node .node_def )
184
+ g3 = tf .Graph ()
185
+ with g3 .as_default ():
186
+ feed_dict = {}
187
+ for inp in input_names :
188
+ feed_dict [inp ] = outputs_to_values [inp ]
189
+ try :
190
+ with tf_session () as sess :
191
+ tf .import_graph_def (mini_graph_def , name = '' )
192
+ results = sess .run (output_names , feed_dict = feed_dict )
193
+ outputs_to_values [output_names [0 ]] = results [0 ]
194
+ outputs_to_dtypes [output_names [0 ]] = node .outputs [0 ].dtype
195
+ progress = True
196
+ except Exception : # pylint: disable=broad-except
197
+ logger .debug ("Could not fold node %s" , node .name )
192
198
unneeded_outputs .update (outputs_to_values .keys ())
193
199
for node in ops :
194
200
# Mark values we need to keep
@@ -205,6 +211,12 @@ def compute_const_folding_using_tf(g, const_node_values):
205
211
del outputs_to_values [node ]
206
212
del outputs_to_dtypes [node ]
207
213
214
+ for node in ops :
215
+ # We don't need the constants any more
216
+ if node .type in ["Const" , "ConstV2" ] and node .outputs [0 ].name in outputs_to_values :
217
+ del outputs_to_values [node .outputs [0 ].name ]
218
+ del outputs_to_dtypes [node .outputs [0 ].name ]
219
+
208
220
logger .info ("Computed %d values for constant folding" , len (outputs_to_values ))
209
221
return outputs_to_values , outputs_to_dtypes
210
222
0 commit comments