Skip to content

Commit f832f96

Browse files
author
wayuanho
authored
Merge pull request #654 from lucienwang1009/extract_subgraph
enhance extracting subgraph
2 parents 0c0e839 + da02842 commit f832f96

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tools/tf_graph_tool.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ def bfs_for_reachable_nodes(target_nodes, name_to_input_name, checker=None):
246246
src_ops = []
247247
def node_checker(n):
248248
if not n.startswith(name_prefix) or n in src_nodes:
249-
src_ops.append(name_to_node[n])
249+
if name_to_node[n] not in src_ops:
250+
src_ops.append(name_to_node[n])
250251
return False
251252
return True
252253
nodes_to_keep = bfs_for_reachable_nodes(dest_nodes, name_to_input_name, checker=node_checker)
@@ -265,10 +266,17 @@ def node_checker(n):
265266
placeholder_node = node_def_pb2.NodeDef()
266267
placeholder_node.op = "Placeholder"
267268
placeholder_node.name = op.name
269+
dtype = None
268270
if str(op.attr["dtype"]):
269-
placeholder_node.attr["dtype"].CopyFrom(op.attr["dtype"])
271+
dtype = op.attr["dtype"]
270272
elif str(op.attr["T"]):
271-
placeholder_node.attr["dtype"].CopyFrom(op.attr["T"])
273+
dtype = op.attr["T"]
274+
elif str(op.attr["output_types"]):
275+
dtype = attr_value_pb2.AttrValue()
276+
dtype.type = op.attr["output_types"].list.type[0]
277+
if dtype is None:
278+
raise RuntimeError("Cannot find dtype for Placeholder: {}".format(op.name))
279+
placeholder_node.attr["dtype"].CopyFrom(dtype)
272280
shape = graph_util.tensor_shape_from_node_def_name(tf_graph, op.name)
273281
placeholder_node.attr["shape"].CopyFrom(
274282
attr_value_pb2.AttrValue(shape=shape.as_proto())

0 commit comments

Comments
 (0)