Skip to content

Commit 8c7ee10

Browse files
committed
fix ops having more than one output as graph outputs
1 parent c4bce3a commit 8c7ee10

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

tests/test_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,12 @@ def test_split(self):
845845
_ = tf.identity(x_, name=_TFOUTPUT)
846846
self._run_test_case([_OUTPUT], {_INPUT: x_val})
847847

848+
def test_split_with_more_outputs(self):
849+
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape(5, 30)
850+
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
851+
x_, y, z = tf.split(x0, [4, 15, 11], 1, name="split_test")
852+
self._run_test_case(["split_test:0", "split_test:1", "split_test:2"], {_INPUT: x_val})
853+
848854
def test_reducesum(self):
849855
# not supported by onnx-caffe2
850856
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))

tf2onnx/graph.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -349,20 +349,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
349349
self.reset_nodes(ops)
350350

351351
# add identity node after each output, in case it is renamed during conversion.
352-
nodes_seen = set()
353-
multi_output_nodes = set()
354352
for o in self.outputs:
355353
n = self.get_node_by_output_in_current_graph(o)
356-
if n in nodes_seen:
357-
multi_output_nodes.add(n)
358-
else:
359-
nodes_seen.add(n)
360-
361-
for o in self.outputs:
362-
n = self.get_node_by_output_in_current_graph(o)
363-
# TODO: below doesn't work for nodes with multiple outputs. A work around, keep those intact.
364-
if n in multi_output_nodes:
365-
continue
366354
new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
367355
n_shapes = n.output_shapes
368356
n_dtypes = n.output_dtypes

0 commit comments

Comments
 (0)