Skip to content

Commit d927ca0

Browse files
Const fold Ifs and remove keras_learning_phase inputs (#1569)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent f943849 commit d927ca0

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

tests/test_backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,18 @@ def func():
903903
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val}, as_session=True,
904904
premade_placeholders=True, process_args={'ignore_default': [_TFINPUT2]})
905905

906+
def test_fold_cond_keras_learning_phase(self):
907+
# keras_learning_phase can slip into frozen graphs and cause huge inefficiencies with If nodes.
908+
# Should be removed and Ifs folded.
909+
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
910+
def func():
911+
x = tf_placeholder(tf.float32, [None, None], name=_TFINPUT)
912+
learning_phase = tf_placeholder_with_default(False, [], name="keras_learning_phase")
913+
y = tf.cond(learning_phase, lambda: x * 2, lambda: x * 3)
914+
return tf.identity(y, name=_TFOUTPUT)
915+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, as_session=True, premade_placeholders=True,
916+
graph_validator=lambda g: check_op_count(g, "If", 0, disabled=False))
917+
906918
@check_onnxruntime_incompatibility("Add")
907919
def test_add_bcast(self):
908920
x1_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ def run(self):
8989
continue
9090

9191
self._cut_off_connection(cond_context)
92-
self._create_if_node(cond_context)
92+
if_node = self._create_if_node(cond_context)
9393
# remove nodes in If branches explicitly
94-
for n in list(cond_context.true_branch_context.nodes) + list(cond_context.false_branch_context.nodes):
95-
self.g.remove_node(n.name)
94+
if if_node is not None:
95+
for n in list(cond_context.true_branch_context.nodes) + list(cond_context.false_branch_context.nodes):
96+
self.g.remove_node(n.name)
9697
logger.debug("cond pre rewrite done")
9798

9899
return self.g.get_nodes()
@@ -136,6 +137,19 @@ def _get_output_shape_dtype(self, cond_context):
136137

137138
def _create_if_node(self, cond_context):
138139
output_shapes, output_dtypes = self._get_output_shape_dtype(cond_context)
140+
pred_node = self.g.get_node_by_output(cond_context.pred_input)
141+
while pred_node.type == "Identity":
142+
pred_node = pred_node.inputs[0]
143+
if pred_node.is_const():
144+
# Constant folding for if node
145+
if pred_node.get_tensor_value():
146+
branch_outputs = cond_context.true_branch_context.output
147+
else:
148+
branch_outputs = cond_context.false_branch_context.output
149+
for merge, out in zip(cond_context.merges, branch_outputs):
150+
self.g.replace_all_inputs(merge.output[0], out)
151+
return None
152+
139153
true_graph = utils.construct_graph_from_nodes(
140154
self.g,
141155
list(cond_context.true_branch_context.nodes),

tf2onnx/tf_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,10 @@ def tflist_to_onnx(g, shape_override, const_node_values=None, ignore_default=Non
438438
input_names = []
439439
elif use_default and node.name in use_default:
440440
node_type = 'Identity'
441+
elif node.name.endswith('keras_learning_phase'):
442+
logger.warning("Removing optional input %s that appears to be a keras learning phase parameter. "
443+
"Use --ignore_default to force this into an input.", node.name)
444+
node_type = 'Identity'
441445

442446
if takeit:
443447
try:

0 commit comments

Comments
 (0)