@@ -38,6 +38,20 @@ def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=T
38
38
return frozen_graph
39
39
40
40
41
+ def remove_redundant_inputs (frozen_graph , input_names ):
42
+ """Remove redundant inputs not in frozen graph."""
43
+ frozen_inputs = []
44
+ # get inputs in frozen graph
45
+ for n in frozen_graph .node :
46
+ for inp in input_names :
47
+ if utils .node_name (inp ) == n .name :
48
+ frozen_inputs .append (inp )
49
+ deleted_inputs = list (set (input_names ) - set (frozen_inputs ))
50
+ if deleted_inputs :
51
+ log .warning ("inputs [%s] is not in frozen graph, delete them" , "," .join (deleted_inputs ))
52
+ return frozen_inputs
53
+
54
+
41
55
def from_graphdef (model_path , input_names , output_names ):
42
56
"""Load tensorflow graph from graphdef."""
43
57
# make sure we start with clean default graph
@@ -48,6 +62,7 @@ def from_graphdef(model_path, input_names, output_names):
48
62
graph_def .ParseFromString (f .read ())
49
63
tf .import_graph_def (graph_def , name = '' )
50
64
frozen_graph = freeze_session (sess , output_names = output_names )
65
+ input_names = remove_redundant_inputs (frozen_graph , input_names )
51
66
# clean up
52
67
tf .reset_default_graph ()
53
68
return frozen_graph , input_names , output_names
@@ -63,6 +78,7 @@ def from_checkpoint(model_path, input_names, output_names):
63
78
# restore from model_path minus the ".meta"
64
79
saver .restore (sess , model_path [:- 5 ])
65
80
frozen_graph = freeze_session (sess , output_names = output_names )
81
+ input_names = remove_redundant_inputs (frozen_graph , input_names )
66
82
# clean up
67
83
tf .reset_default_graph ()
68
84
return frozen_graph , input_names , output_names
@@ -93,15 +109,9 @@ def from_saved_model(model_path, input_names, output_names):
93
109
for _ , output_tensor in sorted (outputs_tensor_info .items ()):
94
110
outputs [output_tensor .name ] = sess .graph .get_tensor_by_name (output_tensor .name )
95
111
frozen_graph = freeze_session (sess , output_names = list (outputs .keys ()))
96
- frozen_inputs = []
97
- # get inputs in frozen graph
98
- for n in frozen_graph .node :
99
- for inp , _ in inputs .items ():
100
- if utils .node_name (inp ) == n .name :
101
- frozen_inputs .append (inp )
102
- deleted_inputs = list (set (inputs .keys ()) - set (frozen_inputs ))
103
- if deleted_inputs :
104
- log .warning ("inputs [%s] is not in frozen graph, delete them" , "," .join (deleted_inputs ))
112
+ if input_names is None :
113
+ input_names = inputs .keys ()
114
+ input_names = remove_redundant_inputs (frozen_graph , input_names )
105
115
# clean up
106
116
tf .reset_default_graph ()
107
- return frozen_graph , frozen_inputs , outputs .keys ()
117
+ return frozen_graph , input_names , outputs .keys ()
0 commit comments