Skip to content

Commit b218213

Browse files
Created compute_const_folding_using_tf
1 parent f5f8b2d commit b218213

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

tf2onnx/tf_utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,73 @@ def compress_graph_def(graph_def):
140140
tensor.tensor_content = b''
141141
return const_node_values
142142

143+
def compute_const_folding_using_tf(g, const_node_values):
144+
"""Find nodes with constant inputs and compute their values using TF"""
145+
if const_node_values is None:
146+
const_node_values = {}
147+
from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
148+
149+
ops = g.get_operations()
150+
outputs_to_values = {}
151+
outputs_to_dtypes = {}
152+
153+
for node in ops:
154+
# Load values of constants. Use const_node_values if possible
155+
if node.type in ["Const", "ConstV2"]:
156+
tensor = node.node_def.attr["value"].tensor
157+
if node.name in const_node_values:
158+
tensor.tensor_content = const_node_values[node.name]
159+
outputs_to_values[node.outputs[0].name] = get_tf_tensor_data(tensor)
160+
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
161+
162+
unneeded_outputs = set()
163+
progress = True
164+
while progress:
165+
progress = False
166+
for node in ops:
167+
# Find ops with constant inputs and compute their values
168+
input_names = [i.name for i in node.inputs]
169+
output_names = [i.name for i in node.outputs]
170+
can_fold = len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
171+
# We can only fold nodes with a single output
172+
can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
173+
# Skip if value already computed, used, and discarded
174+
can_fold = can_fold and output_names[0] not in unneeded_outputs
175+
if can_fold:
176+
g = tf.Graph()
177+
with g.as_default():
178+
for inp in input_names:
179+
tf.compat.v1.placeholder(outputs_to_dtypes[inp], name=inp.split(':')[0])
180+
mini_graph_def = g.as_graph_def()
181+
mini_graph_def.node.append(node.node_def)
182+
tf_reset_default_graph()
183+
feed_dict = {}
184+
for inp in input_names:
185+
feed_dict[inp] = outputs_to_values[inp]
186+
with tf_session() as sess:
187+
tf.import_graph_def(mini_graph_def, name='')
188+
results = sess.run(output_names, feed_dict=feed_dict)
189+
outputs_to_values[output_names[0]] = results[0]
190+
outputs_to_dtypes[output_names[0]] = node.outputs[0].dtype
191+
progress = True
192+
unneeded_outputs.update(outputs_to_values.keys())
193+
for node in ops:
194+
# Mark values we need to keep
195+
input_names = [i.name for i in node.inputs]
196+
output_names = [i.name for i in node.outputs]
197+
if len(output_names) == 1 and output_names[0] in outputs_to_values:
198+
continue
199+
for i in input_names:
200+
if i in unneeded_outputs:
201+
unneeded_outputs.remove(i)
202+
for node in unneeded_outputs:
203+
# Remove unneeded values to prevent memory usage explosion
204+
if node in outputs_to_values:
205+
del outputs_to_values[node]
206+
del outputs_to_dtypes[node]
207+
208+
logger.info("Computed %d values for constant folding", len(outputs_to_values))
209+
return outputs_to_values, outputs_to_dtypes
143210

144211
def tflist_to_onnx(g, shape_override, const_node_values=None):
145212
"""

0 commit comments

Comments
 (0)