Skip to content

Commit c7231d5

Browse files
Keep inputs of placeholder with default when removing unused nodes (#1510)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent ede472a commit c7231d5

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tf2onnx/graph.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,17 +1505,17 @@ def extract_sub_graph_nodes(self, outputs_name, input_checker=None, remove_unuse
15051505
a list of nodes
15061506
"""
15071507
res_set = set()
1508-
if not outputs_name:
1509-
return list(res_set)
15101508

1511-
for output in outputs_name:
1509+
outputs_to_keep = list(outputs_name)
1510+
if not remove_unused_inputs:
1511+
# add placeholder nodes even if they are not connected to outputs.
1512+
# placeholder nodes with defaults can have inputs themselves
1513+
outputs_to_keep += [inp.output[0] for inp in self.inputs]
1514+
1515+
for output in outputs_to_keep:
15121516
node = self.get_node_by_output(output, search_in_parent_graphs=False)
15131517
res_set = res_set.union(self._extract_sub_graph_nodes(node, input_checker))
15141518

1515-
if not remove_unused_inputs:
1516-
# add back placeholder nodes if they are not connected to outputs.
1517-
res_set = res_set.union(self.inputs)
1518-
15191519
return list(res_set)
15201520

15211521
def delete_unused_nodes(self, outputs_name):

0 commit comments

Comments
 (0)