Skip to content

Commit bee1f80

Browse files
authored
Merge pull request #335 from nbcsm/tf_graph_tool
add tf_graph_tool
2 parents 9b3ae10 + 1ffbc16 commit bee1f80

File tree

1 file changed

+217
-0
lines changed

1 file changed

+217
-0
lines changed

tools/tf_graph_tool.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
""" Tool for common tf graph operations. """
5+
6+
from __future__ import division
7+
from __future__ import print_function
8+
9+
import argparse
10+
from collections import Counter
11+
import logging
12+
import os
13+
import sys
14+
15+
from google.protobuf import text_format
16+
import tensorflow as tf
17+
from tensorflow.python.framework import graph_util
18+
19+
# pylint: disable=missing-docstring
20+
21+
logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO)
22+
23+
24+
def get_file_name(path):
25+
return os.path.basename(path)
26+
27+
28+
def get_file_name_without_ext(path):
29+
return '.'.join(get_file_name(path).split('.')[:-1])
30+
31+
32+
def replace_file_extension(path, ext):
33+
tokens = path.split('.')[:-1]
34+
return '.'.join(tokens.appends(ext))
35+
36+
37+
def append_file_name_suffix(path, suffix):
38+
tokens = path.split('.')
39+
tokens[-2] += '_' + suffix
40+
return '.'.join(tokens)
41+
42+
43+
def get_file_directory(path):
44+
return os.path.dirname(path)
45+
46+
47+
def get_file_directory_name(path):
48+
return os.path.basename(get_file_directory(path))
49+
50+
51+
def create_directory(path):
52+
if not os.path.isdir(path):
53+
os.makedirs(path, exist_ok=True)
54+
55+
56+
def load_graph_def_from_pb(path):
57+
tf.reset_default_graph()
58+
graph_def = tf.GraphDef()
59+
with open(path, "rb") as f:
60+
graph_def.ParseFromString(f.read())
61+
return graph_def
62+
63+
64+
def save_graph_def(graph_def, path, as_text=False):
65+
if as_text:
66+
with open(path, "w") as f:
67+
f.write(text_format.MessageToString(graph_def))
68+
else:
69+
with open(path, "wb") as f:
70+
f.write(graph_def.SerializeToString())
71+
72+
73+
def get_node_name(tensor_name):
74+
if tensor_name.startswith("^"):
75+
return tensor_name[1:]
76+
return tensor_name.split(":")[0]
77+
78+
79+
def get_graph_def_io_nodes(graph_def):
80+
consumed = set()
81+
inputs = []
82+
outputs = []
83+
for node in graph_def.node:
84+
for i in node.input:
85+
consumed.add(get_node_name(i))
86+
if node.op in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]:
87+
inputs.append(node.name)
88+
89+
for node in graph_def.node:
90+
if node.name not in consumed and node.name not in inputs:
91+
outputs.append(node.name)
92+
93+
return inputs, outputs
94+
95+
96+
class main(object):
97+
@staticmethod
98+
def convert_pb_to_pbtxt(input_path, output_path=None):
99+
if not output_path:
100+
output_path = replace_file_extension(input_path, "pbtxt")
101+
102+
logging.info("load from %s", input_path)
103+
graph_def = load_graph_def_from_pb(input_path)
104+
105+
logging.info("save to %s", output_path)
106+
save_graph_def(graph_def, output_path, as_text=True)
107+
108+
@staticmethod
109+
def convert_pb_to_summary(input_path, output_dir=None, start_tensorboard=False, port=6006):
110+
if not output_dir:
111+
output_dir = input_path + ".summary"
112+
113+
logging.info("load from %s", input_path)
114+
graph_def = load_graph_def_from_pb(input_path)
115+
116+
logging.info("save to %s", output_dir)
117+
create_directory(output_dir)
118+
with tf.Session() as sess:
119+
tf.import_graph_def(graph_def, name=get_file_name_without_ext(input_path))
120+
train_writer = tf.summary.FileWriter(output_dir)
121+
train_writer.add_graph(sess.graph)
122+
train_writer.close()
123+
124+
if start_tensorboard:
125+
logging.info("launch tensorboard")
126+
os.system("start tensorboard --logdir {} --port {}".format(output_dir, port))
127+
os.system("start http://localhost:{}".format(port))
128+
129+
@staticmethod
130+
def get_graph_io_nodes(input_path):
131+
logging.info("load from %s", input_path)
132+
graph_def = load_graph_def_from_pb(input_path)
133+
inputs, outputs = get_graph_def_io_nodes(graph_def)
134+
logging.info("graph has:")
135+
logging.info("\t%s inputs: %s", len(inputs), ','.join(inputs))
136+
logging.info("\t%s (possible) outputs: %s", len(outputs), ','.join(outputs))
137+
138+
@staticmethod
139+
def print_graph_stat(input_path):
140+
logging.info("load from %s", input_path)
141+
graph_def = load_graph_def_from_pb(input_path)
142+
143+
op_stat = Counter()
144+
for node in graph_def.node:
145+
op_stat[node.op] += 1
146+
147+
logging.info("graph stat:")
148+
for op, count in sorted(op_stat.items(), key=lambda x: x[0]):
149+
logging.info("\t%s = %s", op, count)
150+
151+
@staticmethod
152+
def extract_sub_graph(input_path, output_path=None, dest_nodes=None):
153+
if not output_path:
154+
output_path = append_file_name_suffix(input_path, "sub")
155+
156+
logging.info("load from %s", input_path)
157+
graph_def = load_graph_def_from_pb(input_path)
158+
logging.info("\ttotal node = %s", len(graph_def.node))
159+
160+
if dest_nodes:
161+
dest_nodes = dest_nodes.split(',')
162+
else:
163+
_, dest_nodes = get_graph_def_io_nodes(graph_def)
164+
165+
graph_def = graph_util.extract_sub_graph(graph_def, dest_nodes)
166+
logging.info("save to %s", output_path)
167+
logging.info("\ttotal node = %s", len(graph_def.node))
168+
save_graph_def(graph_def, output_path)
169+
170+
171+
if __name__ == "__main__":
172+
parser = argparse.ArgumentParser()
173+
subparsers = parser.add_subparsers()
174+
175+
# pb2txt
176+
subparser = subparsers.add_parser("pb2txt", help="convert pb to pbtxt")
177+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
178+
subparser.add_argument("--output", dest="output_path", help="output pbtxt path")
179+
subparser.set_defaults(func=main.convert_pb_to_pbtxt)
180+
181+
# pb2summary
182+
subparser = subparsers.add_parser("pb2summary", help="create summary from pb")
183+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
184+
subparser.add_argument("--output", dest="output_dir", help="output summary directory")
185+
subparser.add_argument("--tb", dest="start_tensorboard", action="store_true", default=False,
186+
help="open with tensorboard")
187+
subparser.add_argument("--port", type=int, help="tensorboard port")
188+
subparser.set_defaults(func=main.convert_pb_to_summary)
189+
190+
# io
191+
subparser = subparsers.add_parser("io", help="get input nodes for graph, guess output nodes")
192+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
193+
subparser.set_defaults(func=main.get_graph_io_nodes)
194+
195+
# stat
196+
subparser = subparsers.add_parser("stat", help="print stat")
197+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
198+
subparser.set_defaults(func=main.print_graph_stat)
199+
200+
# extract
201+
subparser = subparsers.add_parser("extract", help="extract sub-graph")
202+
subparser.add_argument("--input", dest="input_path", required=True, help="input pb path")
203+
subparser.add_argument("--output", dest="output_path", help="output pb path")
204+
subparser.add_argument("--dest_nodes", help="dest nodes")
205+
subparser.set_defaults(func=main.extract_sub_graph)
206+
207+
if len(sys.argv) <= 2:
208+
parser.print_help()
209+
sys.exit()
210+
211+
(args, unknown) = parser.parse_known_args()
212+
213+
func = args.func
214+
del args.func
215+
216+
args = dict(filter(lambda x: x[1], vars(args).items()))
217+
func(**args)

0 commit comments

Comments
 (0)