Skip to content

Commit 51767eb

Browse files
Merge pull request #1105 from onnx/tom/LargeModelConstFolding
Added constant folding using TF for large models
2 parents 353e46f + fa8fc63 commit 51767eb

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

tf2onnx/tf_utils.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,85 @@ def compress_graph_def(graph_def):
140140
tensor.tensor_content = b''
141141
return const_node_values
142142

143+
def compute_const_folding_using_tf(g, const_node_values):
    """Find nodes with constant inputs and compute their values using TF"""
    if const_node_values is None:
        const_node_values = {}
    from tf2onnx.tf_loader import tf_session, tf_placeholder  # pylint: disable=import-outside-toplevel

    ops = g.get_operations()
    outputs_to_values = {}
    outputs_to_dtypes = {}

    # Seed the value table with every constant already in the graph. If the
    # graph_def was compressed, the real tensor bytes live in const_node_values
    # and are restored onto the node_def before reading.
    for node in ops:
        if node.type in ["Const", "ConstV2"]:
            tensor = node.node_def.attr["value"].tensor
            if node.name in const_node_values:
                tensor.tensor_content = const_node_values[node.name]
            out = node.outputs[0]
            outputs_to_values[out.name] = get_tf_tensor_data(tensor)
            outputs_to_dtypes[out.name] = out.dtype

    unneeded_outputs = set()
    progress = True
    while progress:
        progress = False
        for node in ops:
            # Find ops with constant inputs and compute their values
            input_names = [inp.name for inp in node.inputs]
            output_names = [out.name for out in node.outputs]
            if node.type in ['Enter']:
                continue
            # Every input must already have a known constant value.
            if not input_names or not all(name in outputs_to_values for name in input_names):
                continue
            # We can only fold nodes with a single output
            if len(output_names) != 1 or output_names[0] in outputs_to_values:
                continue
            # Skip if value already computed, used, and discarded
            if output_names[0] in unneeded_outputs:
                continue
            # Make a mini graph containing just the node to fold
            g2 = tf.Graph()
            with g2.as_default():
                for name in input_names:
                    tf_placeholder(outputs_to_dtypes[name], name=name.split(':')[0])
                mini_graph_def = g2.as_graph_def()
                mini_graph_def.node.append(node.node_def)
            g3 = tf.Graph()
            with g3.as_default():
                feed_dict = {name: outputs_to_values[name] for name in input_names}
                try:
                    with tf_session() as sess:
                        tf.import_graph_def(mini_graph_def, name='')
                        results = sess.run(output_names, feed_dict=feed_dict)
                    outputs_to_values[output_names[0]] = results[0]
                    outputs_to_dtypes[output_names[0]] = node.outputs[0].dtype
                    progress = True
                except Exception:  # pylint: disable=broad-except
                    logger.debug("Could not fold node %s", node.name)

    # Mark every computed value as disposable, then keep those still consumed
    # by a node that was not itself folded away.
    unneeded_outputs.update(outputs_to_values.keys())
    for node in ops:
        output_names = [out.name for out in node.outputs]
        if len(output_names) == 1 and output_names[0] in outputs_to_values:
            continue
        for name in (inp.name for inp in node.inputs):
            unneeded_outputs.discard(name)
    for name in unneeded_outputs:
        # Remove unneeded values to prevent memory usage explosion
        if name in outputs_to_values:
            del outputs_to_values[name]
            del outputs_to_dtypes[name]

    # We don't need the constants any more; keep only newly folded values.
    for node in ops:
        if node.type in ["Const", "ConstV2"] and node.outputs[0].name in outputs_to_values:
            del outputs_to_values[node.outputs[0].name]
            del outputs_to_dtypes[node.outputs[0].name]

    logger.info("Computed %d values for constant folding", len(outputs_to_values))
    return outputs_to_values, outputs_to_dtypes
143222

144223
def tflist_to_onnx(g, shape_override, const_node_values=None):
145224
"""

tf2onnx/tfonnx.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tf2onnx.rewriter import * # pylint: disable=wildcard-import
2424
from tf2onnx.shape_inference import infer_shape
2525
from tf2onnx.tf_loader import is_function, resolve_functions, set_function
26-
from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version
26+
from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version, compute_const_folding_using_tf
2727

2828
from . import constants, logging, schemas, utils, handler
2929

@@ -33,6 +33,35 @@
3333
# pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
3434
# pylint: disable=unused-variable
3535

36+
def fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes):
    """Replace nodes whose values were precomputed by TF with const nodes."""
    nodes = g.get_nodes()
    # pylint: disable=too-many-nested-blocks
    keep_looking = True
    while keep_looking:
        keep_looking = False
        for idx, node in enumerate(nodes):
            if not node.output or node.output[0] not in outputs_to_values:
                continue
            logger.info("folding node using tf type=%s, name=%s" % (node.type, node.name))
            value = outputs_to_values[node.output[0]]

            const_name = utils.make_name(node.name)
            old_output = node.output[0]
            logger.debug("create const node [%s] replacing [%s]", const_name, node.name)
            nodes[idx] = g.make_const(const_name, value)

            logger.debug("replace old output [%s] with new output [%s]", old_output, const_name)
            # need to re-write the consumers input name to use the const name
            consumers = g.find_output_consumers(old_output)
            if consumers:
                for consumer in consumers:
                    g.replace_input(consumer, old_output, const_name)

            # keep looking until there is nothing we can fold.
            keep_looking = True

    g.reset_nodes(nodes)
3665

3766
def rewrite_constant_fold(g, ops):
3867
"""
@@ -378,6 +407,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
378407
if target is None:
379408
target = constants.DEFAULT_TARGET
380409

410+
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values)
411+
381412
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
382413
tensorflow_to_onnx(tf_graph, shape_override, const_node_values)
383414
if not is_subgraph:
@@ -451,6 +482,8 @@ def compat_handler(ctx, node, **kwargs):
451482
if inputs_as_nchw:
452483
transpose_inputs(g, inputs_as_nchw)
453484

485+
fold_constants_using_tf(g, outputs_to_values, outputs_to_dtypes)
486+
454487
# pre-processing graph rewrites
455488
# bi-directional re-writer should be placed after single directional re-writer
456489
rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,

0 commit comments

Comments
 (0)