@@ -140,6 +140,73 @@ def compress_graph_def(graph_def):
140
140
tensor .tensor_content = b''
141
141
return const_node_values
142
142
143
+ def compute_const_folding_using_tf (g , const_node_values ):
144
+ """Find nodes with constant inputs and compute their values using TF"""
145
+ if const_node_values is None :
146
+ const_node_values = {}
147
+ from tf2onnx .tf_loader import tf_session , tf_reset_default_graph
148
+
149
+ ops = g .get_operations ()
150
+ outputs_to_values = {}
151
+ outputs_to_dtypes = {}
152
+
153
+ for node in ops :
154
+ # Load values of constants. Use const_node_values if possible
155
+ if node .type in ["Const" , "ConstV2" ]:
156
+ tensor = node .node_def .attr ["value" ].tensor
157
+ if node .name in const_node_values :
158
+ tensor .tensor_content = const_node_values [node .name ]
159
+ outputs_to_values [node .outputs [0 ].name ] = get_tf_tensor_data (tensor )
160
+ outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
161
+
162
+ unneeded_outputs = set ()
163
+ progress = True
164
+ while progress :
165
+ progress = False
166
+ for node in ops :
167
+ # Find ops with constant inputs and compute their values
168
+ input_names = [i .name for i in node .inputs ]
169
+ output_names = [i .name for i in node .outputs ]
170
+ can_fold = len (input_names ) > 0 and all (inp in outputs_to_values for inp in input_names )
171
+ # We can only fold nodes with a single output
172
+ can_fold = can_fold and len (output_names ) == 1 and output_names [0 ] not in outputs_to_values
173
+ # Skip if value already computed, used, and discarded
174
+ can_fold = can_fold and output_names [0 ] not in unneeded_outputs
175
+ if can_fold :
176
+ g = tf .Graph ()
177
+ with g .as_default ():
178
+ for inp in input_names :
179
+ tf .compat .v1 .placeholder (outputs_to_dtypes [inp ], name = inp .split (':' )[0 ])
180
+ mini_graph_def = g .as_graph_def ()
181
+ mini_graph_def .node .append (node .node_def )
182
+ tf_reset_default_graph ()
183
+ feed_dict = {}
184
+ for inp in input_names :
185
+ feed_dict [inp ] = outputs_to_values [inp ]
186
+ with tf_session () as sess :
187
+ tf .import_graph_def (mini_graph_def , name = '' )
188
+ results = sess .run (output_names , feed_dict = feed_dict )
189
+ outputs_to_values [output_names [0 ]] = results [0 ]
190
+ outputs_to_dtypes [output_names [0 ]] = node .outputs [0 ].dtype
191
+ progress = True
192
+ unneeded_outputs .update (outputs_to_values .keys ())
193
+ for node in ops :
194
+ # Mark values we need to keep
195
+ input_names = [i .name for i in node .inputs ]
196
+ output_names = [i .name for i in node .outputs ]
197
+ if len (output_names ) == 1 and output_names [0 ] in outputs_to_values :
198
+ continue
199
+ for i in input_names :
200
+ if i in unneeded_outputs :
201
+ unneeded_outputs .remove (i )
202
+ for node in unneeded_outputs :
203
+ # Remove unneeded values to prevent memory usage explosion
204
+ if node in outputs_to_values :
205
+ del outputs_to_values [node ]
206
+ del outputs_to_dtypes [node ]
207
+
208
+ logger .info ("Computed %d values for constant folding" , len (outputs_to_values ))
209
+ return outputs_to_values , outputs_to_dtypes
143
210
144
211
def tflist_to_onnx (g , shape_override , const_node_values = None ):
145
212
"""
0 commit comments