Skip to content

Commit 505ce17

Browse files
authored
Merge pull request #456 from lucienwang1009/remove_useless_inputs
remove useless inputs
2 parents 443e154 + c8312bf commit 505ce17

File tree

3 files changed

+52
-17
lines changed

3 files changed

+52
-17
lines changed

tests/run_pretrained_models.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,20 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
222222
dir_name = os.path.dirname(self.local)
223223
print("\tdownloaded", model_path)
224224

225-
inputs = list(self.input_names.keys())
225+
input_names = list(self.input_names.keys())
226226
outputs = self.output_names
227227
if self.model_type in ["checkpoint"]:
228-
graph_def, inputs, outputs = loader.from_checkpoint(model_path, inputs, outputs)
228+
graph_def, input_names, outputs = loader.from_checkpoint(model_path, input_names, outputs)
229229
elif self.model_type in ["saved_model"]:
230-
graph_def, inputs, outputs = loader.from_saved_model(model_path, inputs, outputs)
230+
graph_def, input_names, outputs = loader.from_saved_model(model_path, input_names, outputs)
231231
else:
232-
graph_def, inputs, outputs = loader.from_graphdef(model_path, inputs, outputs)
232+
graph_def, input_names, outputs = loader.from_graphdef(model_path, input_names, outputs)
233233

234234
# create the input data
235235
inputs = {}
236236
for k, v in self.input_names.items():
237+
if k not in input_names:
238+
continue
237239
if isinstance(v, six.text_type) and v.startswith("np."):
238240
inputs[k] = eval(v) # pylint: disable=eval-used
239241
else:
@@ -312,10 +314,12 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
312314
print("\tResults: OK")
313315
return True
314316
except Exception as ex:
315-
print("\tResults: ", ex)
317+
tb = traceback.format_exc()
318+
print("\tResults", ex, tb)
316319

317320
except Exception as ex:
318-
print("\trun_onnx", "FAIL", ex)
321+
tb = traceback.format_exc()
322+
print("\trun_onnx", "FAIL", ex, tb)
319323

320324
return False
321325

@@ -399,7 +403,8 @@ def main():
399403
fold_const=args.fold_const)
400404
except Exception as ex:
401405
ret = None
402-
print(ex)
406+
tb = traceback.format_exc()
407+
print(ex, tb)
403408
finally:
404409
if not args.debug:
405410
utils.delete_directory(TEMP_DIR)

tests/run_pretrained_models.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,30 @@ saved_model_with_redundant_inputs:
3232
model_type: saved_model
3333
input_get: get_ramp
3434
inputs:
35+
"X:0": [1, 10]
3536
"Placeholder:0": [1, 10]
3637
outputs:
3738
- Add:0
3839

40+
graphdef_with_redundant_inputs:
41+
model: tests/models/regression/graphdef/frozen.pb
42+
input_get: get_ramp
43+
inputs:
44+
"X:0": [1, 10]
45+
"Placeholder:0": [1, 10]
46+
outputs:
47+
- Add:0
48+
49+
checkpoint_with_redundant_inputs:
50+
model: tests/models/regression/checkpoint/model.meta
51+
model_type: checkpoint
52+
input_get: get_ramp
53+
inputs:
54+
"X:0": [1]
55+
"Placeholder:0": [1, 10]
56+
outputs:
57+
- pred:0
58+
3959
benchtf-fc:
4060
model: tests/models/fc-layers/frozen.pb
4161
input_get: get_ramp

tf2onnx/loader.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@ def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=T
3838
return frozen_graph
3939

4040

41+
def remove_redundant_inputs(frozen_graph, input_names):
42+
"""Remove redundant inputs not in frozen graph."""
43+
frozen_inputs = []
44+
# get inputs in frozen graph
45+
for n in frozen_graph.node:
46+
for inp in input_names:
47+
if utils.node_name(inp) == n.name:
48+
frozen_inputs.append(inp)
49+
deleted_inputs = list(set(input_names) - set(frozen_inputs))
50+
if deleted_inputs:
51+
log.warning("inputs [%s] is not in frozen graph, delete them", ",".join(deleted_inputs))
52+
return frozen_inputs
53+
54+
4155
def from_graphdef(model_path, input_names, output_names):
4256
"""Load tensorflow graph from graphdef."""
4357
# make sure we start with clean default graph
@@ -48,6 +62,7 @@ def from_graphdef(model_path, input_names, output_names):
4862
graph_def.ParseFromString(f.read())
4963
tf.import_graph_def(graph_def, name='')
5064
frozen_graph = freeze_session(sess, output_names=output_names)
65+
input_names = remove_redundant_inputs(frozen_graph, input_names)
5166
# clean up
5267
tf.reset_default_graph()
5368
return frozen_graph, input_names, output_names
@@ -63,6 +78,7 @@ def from_checkpoint(model_path, input_names, output_names):
6378
# restore from model_path minus the ".meta"
6479
saver.restore(sess, model_path[:-5])
6580
frozen_graph = freeze_session(sess, output_names=output_names)
81+
input_names = remove_redundant_inputs(frozen_graph, input_names)
6682
# clean up
6783
tf.reset_default_graph()
6884
return frozen_graph, input_names, output_names
@@ -93,15 +109,9 @@ def from_saved_model(model_path, input_names, output_names):
93109
for _, output_tensor in sorted(outputs_tensor_info.items()):
94110
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
95111
frozen_graph = freeze_session(sess, output_names=list(outputs.keys()))
96-
frozen_inputs = []
97-
# get inputs in frozen graph
98-
for n in frozen_graph.node:
99-
for inp, _ in inputs.items():
100-
if utils.node_name(inp) == n.name:
101-
frozen_inputs.append(inp)
102-
deleted_inputs = list(set(inputs.keys()) - set(frozen_inputs))
103-
if deleted_inputs:
104-
log.warning("inputs [%s] is not in frozen graph, delete them", ",".join(deleted_inputs))
112+
if input_names is None:
113+
input_names = inputs.keys()
114+
input_names = remove_redundant_inputs(frozen_graph, input_names)
105115
# clean up
106116
tf.reset_default_graph()
107-
return frozen_graph, frozen_inputs, outputs.keys()
117+
return frozen_graph, input_names, outputs.keys()

0 commit comments

Comments
 (0)