Skip to content

Commit d0ba20e

Browse files
authored
Update parent graph in loop "cond" subgraphs (#2201)
* handle while cond graph subgraph parent update * add while loop cond subgrpah test --------- Signed-off-by: Salvetti, Francesco <[email protected]>
1 parent 25c977c commit d0ba20e

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

tests/test_loops.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from backend_test_base import Tf2OnnxBackendTestBase
1010
from common import unittest_main, check_tf_min_version, check_tf_max_version, \
11-
check_onnxruntime_min_version, check_tfjs_max_version
11+
check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite
1212
from tf2onnx.tf_loader import is_tf2
1313

1414

@@ -302,6 +302,23 @@ def func(i):
302302
output_names_with_port = ["output:0"]
303303
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
304304

305+
@check_tf_min_version("2")
306+
@skip_tflite("cond_graph conversion fails with tflite")
307+
def test_while_loop_cond_subgraphs(self):
308+
# test for while_loop with subgraphs in cond
309+
# Note: this is not working on tf1
310+
def func(x):
311+
x_dim = tf.shape(x)[0]
312+
r = tf.cast(tf.zeros(1), x.dtype)
313+
for i in tf.range(10):
314+
if i == x_dim:
315+
break
316+
r += x[i]
317+
return tf.identity(r, name="output")
318+
input_names_with_port = ["input_1:0"]
319+
feed_dict = {"input_1:0": np.arange(0, 15, dtype=np.int32)}
320+
output_names_with_port = ["output:0"]
321+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
305322

306323
if __name__ == '__main__':
307324
unittest_main()

tf2onnx/onnx_opset/controlflow.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,15 @@ def inline_subgraph(parent, g, scope, binding):
730730
for n in g.get_nodes():
731731
dtypes = n.output_dtypes
732732
shapes = n.output_shapes
733-
n.graph = parent
733+
subgraphs = n.get_body_graphs()
734+
735+
n.graph = parent # we must change node graph exactly here so that previous/following code can work
736+
737+
# if n has subgraphs, we need to set the correct parent graph for them
738+
if subgraphs:
739+
for sub_name, sub_graph in subgraphs.items():
740+
n.set_body_graph_as_attr(sub_name, sub_graph)
741+
734742
for name, shape, dtype in zip(n.output, shapes, dtypes):
735743
# FIXME: don't access this directly
736744
parent._output_shapes[name] = shape # pylint: disable=protected-access

0 commit comments

Comments
 (0)