Skip to content

Commit 86cb02a

Browse files
committed
add tf_graph_tool
1 parent 5810313 commit 86cb02a

File tree

1 file changed

+198
-0
lines changed

1 file changed

+198
-0
lines changed

tools/tf_graph_tool.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
""" Tool for common tf graph operations. """
5+
6+
from __future__ import division
7+
from __future__ import print_function
8+
9+
import argparse
10+
import logging
11+
import os
12+
import sys
13+
14+
from google.protobuf import text_format
15+
import tensorflow as tf
16+
from tensorflow.python.framework import graph_util
17+
18+
# pylint: disable=missing-docstring
19+
20+
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
21+
22+
23+
def get_file_name(path):
24+
return os.path.basename(path)
25+
26+
27+
def get_file_name_without_ext(path):
28+
return '.'.join(get_file_name(path).split('.')[:-1])
29+
30+
31+
def replace_file_extension(path, ext):
32+
tokens = path.split('.')[:-1]
33+
return '.'.join(tokens.appends(ext))
34+
35+
36+
def append_file_name_suffix(path, suffix):
37+
tokens = path.split('.')
38+
tokens[-2] += '_' + suffix
39+
return '.'.join(tokens)
40+
41+
42+
def get_file_directory(path):
43+
return os.path.dirname(path)
44+
45+
46+
def get_file_directory_name(path):
47+
return os.path.basename(get_file_directory(path))
48+
49+
50+
def create_directory(path):
51+
if not os.path.isdir(path):
52+
os.makedirs(path, exist_ok=True)
53+
54+
55+
def load_graph_def_from_pb(path):
56+
tf.reset_default_graph()
57+
graph_def = tf.GraphDef()
58+
with open(path, "rb") as f:
59+
graph_def.ParseFromString(f.read())
60+
return graph_def
61+
62+
63+
def save_graph_def(graph_def, path, as_text=False):
64+
if as_text:
65+
with open(path, "w") as f:
66+
f.write(text_format.MessageToString(graph_def))
67+
else:
68+
with open(path, "wb") as f:
69+
f.write(graph_def.SerializeToString())
70+
71+
72+
def get_node_name(tensor_name):
73+
if tensor_name.startswith("^"):
74+
return tensor_name[1:]
75+
return tensor_name.split(":")[0]
76+
77+
78+
def get_graph_def_io_nodes(graph_def):
79+
consumed = set()
80+
inputs = []
81+
outputs = []
82+
for node in graph_def.node:
83+
for i in node.input:
84+
consumed.add(get_node_name(i))
85+
if node.op in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]:
86+
inputs.append(node.name)
87+
88+
for node in graph_def.node:
89+
if node.name not in consumed and node.name not in inputs:
90+
outputs.append(node.name)
91+
92+
return inputs, outputs
93+
94+
95+
class main(object):
96+
@staticmethod
97+
def convert_pb_to_pbtxt(input_path, output_path=None):
98+
if not output_path:
99+
output_path = replace_file_extension(input_path, "pbtxt")
100+
101+
logging.info("load from %s", input_path)
102+
graph_def = load_graph_def_from_pb(input_path)
103+
104+
logging.info("save to %s", output_path)
105+
save_graph_def(graph_def, output_path, as_text=True)
106+
107+
@staticmethod
108+
def convert_pb_to_summary(input_path, output_dir=None, start_tensorboard=False, port=6006):
109+
if not output_dir:
110+
output_dir = input_path + ".summary"
111+
112+
logging.info("load from %s", input_path)
113+
graph_def = load_graph_def_from_pb(input_path)
114+
115+
logging.info("save to %s", output_dir)
116+
create_directory(output_dir)
117+
with tf.Session() as sess:
118+
tf.import_graph_def(graph_def, name=get_file_name_without_ext(input_path))
119+
train_writer = tf.summary.FileWriter(output_dir)
120+
train_writer.add_graph(sess.graph)
121+
train_writer.close()
122+
123+
if start_tensorboard:
124+
logging.info("launch tensorboard")
125+
os.system("start tensorboard --logdir {} --port {}".format(output_dir, port))
126+
os.system("start http://localhost:{}".format(port))
127+
128+
@staticmethod
129+
def get_graph_io_nodes(input_path):
130+
logging.info("load from %s", input_path)
131+
graph_def = load_graph_def_from_pb(input_path)
132+
inputs, outputs = get_graph_def_io_nodes(graph_def)
133+
logging.info("graph has:")
134+
logging.info("\t%s inputs: %s", len(inputs), ','.join(inputs))
135+
logging.info("\t%s (possible) outputs: %s", len(inputs), ','.join(outputs))
136+
137+
@staticmethod
138+
def extract_sub_graph(input_path, output_path=None, dest_nodes=None):
139+
if not output_path:
140+
output_path = append_file_name_suffix(input_path, "sub")
141+
142+
logging.info("load from %s", input_path)
143+
graph_def = load_graph_def_from_pb(input_path)
144+
logging.info("\ttotal node = %s", len(graph_def.node))
145+
146+
if dest_nodes:
147+
dest_nodes = dest_nodes.split(',')
148+
else:
149+
_, dest_nodes = get_graph_def_io_nodes(graph_def)
150+
151+
graph_def = graph_util.extract_sub_graph(graph_def, dest_nodes)
152+
logging.info("save to %s", output_path)
153+
logging.info("\ttotal node = %s", len(graph_def.node))
154+
save_graph_def(graph_def, output_path)
155+
156+
157+
if __name__ == "__main__":
158+
parser = argparse.ArgumentParser()
159+
subparsers = parser.add_subparsers()
160+
161+
# pb2txt
162+
subparser = subparsers.add_parser("pb2txt", help="convert pb to pbtxt")
163+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
164+
subparser.add_argument("--output", dest="output_path", help="output pbtxt path")
165+
subparser.set_defaults(func=main.convert_pb_to_pbtxt)
166+
167+
# pb2summary
168+
subparser = subparsers.add_parser("pb2summary", help="create summary from pb")
169+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
170+
subparser.add_argument("--output", dest="output_dir", help="output summary directory")
171+
subparser.add_argument("--tb", dest="start_tensorboard", action="store_true", default=False,
172+
help="open with tensorboard")
173+
subparser.add_argument("--port", type=int, help="tensorboard port")
174+
subparser.set_defaults(func=main.convert_pb_to_summary)
175+
176+
# io
177+
subparser = subparsers.add_parser("io", help="get input nodes for graph, guess output nodes")
178+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
179+
subparser.set_defaults(func=main.get_graph_io_nodes)
180+
181+
# extract
182+
subparser = subparsers.add_parser("extract", help="extract sub-graph")
183+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
184+
subparser.add_argument("--output", dest="output_path", help="output pb path")
185+
subparser.add_argument("--dest_nodes", help="dest nodes")
186+
subparser.set_defaults(func=main.extract_sub_graph)
187+
188+
if len(sys.argv) <= 2:
189+
parser.print_help()
190+
sys.exit()
191+
192+
(args, unknown) = parser.parse_known_args()
193+
194+
func = args.func
195+
del args.func
196+
197+
args = dict(filter(lambda x: x[1], vars(args).items()))
198+
func(**args)

0 commit comments

Comments
 (0)