Skip to content

Commit fa8fc63

Browse files
Created fold_constants_using_tf
1 parent b218213 commit fa8fc63

File tree

2 files changed

+63
-18
lines changed

2 files changed

+63
-18
lines changed

tf2onnx/tf_utils.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def compute_const_folding_using_tf(g, const_node_values):
144144
"""Find nodes with constant inputs and compute their values using TF"""
145145
if const_node_values is None:
146146
const_node_values = {}
147-
from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
147+
from tf2onnx.tf_loader import tf_session, tf_placeholder # pylint: disable=import-outside-toplevel
148148

149149
ops = g.get_operations()
150150
outputs_to_values = {}
@@ -167,28 +167,34 @@ def compute_const_folding_using_tf(g, const_node_values):
167167
# Find ops with constant inputs and compute their values
168168
input_names = [i.name for i in node.inputs]
169169
output_names = [i.name for i in node.outputs]
170-
can_fold = len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
170+
can_fold = node.type not in ['Enter']
171+
can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
171172
# We can only fold nodes with a single output
172173
can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
173174
# Skip if value already computed, used, and discarded
174175
can_fold = can_fold and output_names[0] not in unneeded_outputs
175176
if can_fold:
176-
g = tf.Graph()
177-
with g.as_default():
177+
# Make a mini graph containing just the node to fold
178+
g2 = tf.Graph()
179+
with g2.as_default():
178180
for inp in input_names:
179-
tf.compat.v1.placeholder(outputs_to_dtypes[inp], name=inp.split(':')[0])
180-
mini_graph_def = g.as_graph_def()
181-
mini_graph_def.node.append(node.node_def)
182-
tf_reset_default_graph()
183-
feed_dict = {}
184-
for inp in input_names:
185-
feed_dict[inp] = outputs_to_values[inp]
186-
with tf_session() as sess:
187-
tf.import_graph_def(mini_graph_def, name='')
188-
results = sess.run(output_names, feed_dict=feed_dict)
189-
outputs_to_values[output_names[0]] = results[0]
190-
outputs_to_dtypes[output_names[0]] = node.outputs[0].dtype
191-
progress = True
181+
tf_placeholder(outputs_to_dtypes[inp], name=inp.split(':')[0])
182+
mini_graph_def = g2.as_graph_def()
183+
mini_graph_def.node.append(node.node_def)
184+
g3 = tf.Graph()
185+
with g3.as_default():
186+
feed_dict = {}
187+
for inp in input_names:
188+
feed_dict[inp] = outputs_to_values[inp]
189+
try:
190+
with tf_session() as sess:
191+
tf.import_graph_def(mini_graph_def, name='')
192+
results = sess.run(output_names, feed_dict=feed_dict)
193+
outputs_to_values[output_names[0]] = results[0]
194+
outputs_to_dtypes[output_names[0]] = node.outputs[0].dtype
195+
progress = True
196+
except Exception: # pylint: disable=broad-except
197+
logger.debug("Could not fold node %s", node.name)
192198
unneeded_outputs.update(outputs_to_values.keys())
193199
for node in ops:
194200
# Mark values we need to keep
@@ -205,6 +211,12 @@ def compute_const_folding_using_tf(g, const_node_values):
205211
del outputs_to_values[node]
206212
del outputs_to_dtypes[node]
207213

214+
for node in ops:
215+
# We don't need the constants any more
216+
if node.type in ["Const", "ConstV2"] and node.outputs[0].name in outputs_to_values:
217+
del outputs_to_values[node.outputs[0].name]
218+
del outputs_to_dtypes[node.outputs[0].name]
219+
208220
logger.info("Computed %d values for constant folding", len(outputs_to_values))
209221
return outputs_to_values, outputs_to_dtypes
210222

tf2onnx/tfonnx.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tf2onnx.rewriter import * # pylint: disable=wildcard-import
2424
from tf2onnx.shape_inference import infer_shape
2525
from tf2onnx.tf_loader import is_function, resolve_functions, set_function
26-
from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version
26+
from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version, compute_const_folding_using_tf
2727

2828
from . import constants, logging, schemas, utils, handler
2929

@@ -33,6 +33,35 @@
3333
# pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
3434
# pylint: disable=unused-variable
3535

36+
def fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes):
    """Replace nodes whose output value was precomputed by TF with Const nodes.

    Args:
        g: the tf2onnx graph being converted (mutated in place).
        outputs_to_values: dict mapping TF output names to their precomputed
            numpy values (from compute_const_folding_using_tf).
        outputs_to_dtypes: dict mapping TF output names to their dtypes.
            NOTE(review): currently unused here; kept for interface symmetry
            with compute_const_folding_using_tf's return value.
    """
    ops = g.get_nodes()
    # pylint: disable=too-many-nested-blocks
    keep_looking = True
    while keep_looking:
        keep_looking = False
        for idx, op in enumerate(ops):
            if op.output and op.output[0] in outputs_to_values:
                # Lazy %-args so formatting is skipped when INFO is disabled
                # (matches the logger.debug calls below).
                logger.info("folding node using tf type=%s, name=%s", op.type, op.name)
                val = outputs_to_values[op.output[0]]

                new_node_name = utils.make_name(op.name)
                new_output_name = new_node_name
                old_output_name = op.output[0]
                old_node_name = op.name
                logger.debug("create const node [%s] replacing [%s]", new_node_name, old_node_name)
                ops[idx] = g.make_const(new_node_name, val)

                logger.debug("replace old output [%s] with new output [%s]", old_output_name, new_output_name)
                # need to re-write the consumers input name to use the const name
                consumers = g.find_output_consumers(old_output_name)
                if consumers:
                    for consumer in consumers:
                        g.replace_input(consumer, old_output_name, new_output_name)

                # keep looking until there is nothing we can fold.
                keep_looking = True

    g.reset_nodes(ops)
3665

3766
def rewrite_constant_fold(g, ops):
3867
"""
@@ -378,6 +407,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
378407
if target is None:
379408
target = constants.DEFAULT_TARGET
380409

410+
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values)
411+
381412
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
382413
tensorflow_to_onnx(tf_graph, shape_override, const_node_values)
383414
if not is_subgraph:
@@ -451,6 +482,8 @@ def compat_handler(ctx, node, **kwargs):
451482
if inputs_as_nchw:
452483
transpose_inputs(g, inputs_as_nchw)
453484

485+
fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes)
486+
454487
# pre-processing graph rewrites
455488
# bi-directional re-writer should be placed after single directional re-writer
456489
rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,

0 commit comments

Comments
 (0)