@@ -140,6 +140,85 @@ def compress_graph_def(graph_def):
140
140
tensor .tensor_content = b''
141
141
return const_node_values
142
142
143
+ def compute_const_folding_using_tf (g , const_node_values ):
144
+ """Find nodes with constant inputs and compute their values using TF"""
145
+ if const_node_values is None :
146
+ const_node_values = {}
147
+ from tf2onnx .tf_loader import tf_session , tf_placeholder # pylint: disable=import-outside-toplevel
148
+
149
+ ops = g .get_operations ()
150
+ outputs_to_values = {}
151
+ outputs_to_dtypes = {}
152
+
153
+ for node in ops :
154
+ # Load values of constants. Use const_node_values if possible
155
+ if node .type in ["Const" , "ConstV2" ]:
156
+ tensor = node .node_def .attr ["value" ].tensor
157
+ if node .name in const_node_values :
158
+ tensor .tensor_content = const_node_values [node .name ]
159
+ outputs_to_values [node .outputs [0 ].name ] = get_tf_tensor_data (tensor )
160
+ outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
161
+
162
+ unneeded_outputs = set ()
163
+ progress = True
164
+ while progress :
165
+ progress = False
166
+ for node in ops :
167
+ # Find ops with constant inputs and compute their values
168
+ input_names = [i .name for i in node .inputs ]
169
+ output_names = [i .name for i in node .outputs ]
170
+ can_fold = node .type not in ['Enter' ]
171
+ can_fold = can_fold and len (input_names ) > 0 and all (inp in outputs_to_values for inp in input_names )
172
+ # We can only fold nodes with a single output
173
+ can_fold = can_fold and len (output_names ) == 1 and output_names [0 ] not in outputs_to_values
174
+ # Skip if value already computed, used, and discarded
175
+ can_fold = can_fold and output_names [0 ] not in unneeded_outputs
176
+ if can_fold :
177
+ # Make a mini graph containing just the node to fold
178
+ g2 = tf .Graph ()
179
+ with g2 .as_default ():
180
+ for inp in input_names :
181
+ tf_placeholder (outputs_to_dtypes [inp ], name = inp .split (':' )[0 ])
182
+ mini_graph_def = g2 .as_graph_def ()
183
+ mini_graph_def .node .append (node .node_def )
184
+ g3 = tf .Graph ()
185
+ with g3 .as_default ():
186
+ feed_dict = {}
187
+ for inp in input_names :
188
+ feed_dict [inp ] = outputs_to_values [inp ]
189
+ try :
190
+ with tf_session () as sess :
191
+ tf .import_graph_def (mini_graph_def , name = '' )
192
+ results = sess .run (output_names , feed_dict = feed_dict )
193
+ outputs_to_values [output_names [0 ]] = results [0 ]
194
+ outputs_to_dtypes [output_names [0 ]] = node .outputs [0 ].dtype
195
+ progress = True
196
+ except Exception : # pylint: disable=broad-except
197
+ logger .debug ("Could not fold node %s" , node .name )
198
+ unneeded_outputs .update (outputs_to_values .keys ())
199
+ for node in ops :
200
+ # Mark values we need to keep
201
+ input_names = [i .name for i in node .inputs ]
202
+ output_names = [i .name for i in node .outputs ]
203
+ if len (output_names ) == 1 and output_names [0 ] in outputs_to_values :
204
+ continue
205
+ for i in input_names :
206
+ if i in unneeded_outputs :
207
+ unneeded_outputs .remove (i )
208
+ for node in unneeded_outputs :
209
+ # Remove unneeded values to prevent memory usage explosion
210
+ if node in outputs_to_values :
211
+ del outputs_to_values [node ]
212
+ del outputs_to_dtypes [node ]
213
+
214
+ for node in ops :
215
+ # We don't need the constants any more
216
+ if node .type in ["Const" , "ConstV2" ] and node .outputs [0 ].name in outputs_to_values :
217
+ del outputs_to_values [node .outputs [0 ].name ]
218
+ del outputs_to_dtypes [node .outputs [0 ].name ]
219
+
220
+ logger .info ("Computed %d values for constant folding" , len (outputs_to_values ))
221
+ return outputs_to_values , outputs_to_dtypes
143
222
144
223
def tflist_to_onnx (g , shape_override , const_node_values = None ):
145
224
"""
0 commit comments