Skip to content

Commit f2dc6d8

Browse files
authored
Merge pull request #397 from pengwa/fix_issue_
fix an issue when subgraph output node has more than one output
2 parents efb4bce + 4b0b190 commit f2dc6d8

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

tests/test_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,12 @@ def test_split(self):
883883
_ = tf.identity(x_, name=_TFOUTPUT)
884884
self._run_test_case([_OUTPUT], {_INPUT: x_val})
885885

886+
def test_split_with_more_outputs(self):
887+
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape(5, 30)
888+
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
889+
_, _, _ = tf.split(x0, [4, 15, 11], 1, name="split_test")
890+
self._run_test_case(["split_test:0", "split_test:1", "split_test:2"], {_INPUT: x_val})
891+
886892
def test_reducesum(self):
887893
# not supported by onnx-caffe2
888894
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))

tf2onnx/graph.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -357,27 +357,15 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
357357
self.reset_nodes(ops)
358358

359359
# add identity node after each output, in case it is renamed during conversion.
360-
nodes_seen = set()
361-
multi_output_nodes = set()
362360
for o in self.outputs:
363361
n = self.get_node_by_output_in_current_graph(o)
364-
if n in nodes_seen:
365-
multi_output_nodes.add(n)
366-
else:
367-
nodes_seen.add(n)
368-
369-
for o in self.outputs:
370-
n = self.get_node_by_output_in_current_graph(o)
371-
# TODO: below doesn't work for nodes with multiple outputs. A work around, keep those intact.
372-
if n in multi_output_nodes:
373-
continue
374362
new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
375363
n_shapes = n.output_shapes
376364
n_dtypes = n.output_dtypes
377365
body_graphs = n.graph.contained_graphs.pop(n.name, None)
378366
self.remove_node(n.name)
379367

380-
new_outputs = [o if o != output else new_output_name for output in n.output]
368+
new_outputs = [output if output != o else new_output_name for output in n.output]
381369
# domain should be passed to new node
382370
new_node = self.make_node(n.type, n.input, outputs=new_outputs, attr=n.attr, name=n.name,
383371
skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes,

0 commit comments

Comments
 (0)