Skip to content

Commit efb4bce

Browse files
authored
Merge pull request #409 from lucienwang1009/bugs
Load saved model with redundant inputs
2 parents c6e55d6 + 6a7cc79 commit efb4bce

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed
Binary file not shown.

tests/run_pretrained_models.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ regression-saved-model:
2727
outputs:
2828
- pred:0
2929

30+
saved_model_with_redundant_inputs:
31+
model: tests/models/saved_model_with_redundant_inputs
32+
model_type: saved_model
33+
input_get: get_ramp
34+
inputs:
35+
"Placeholder:0": [1, 10]
36+
outputs:
37+
- Add:0
38+
3039
benchtf-fc:
3140
model: tests/models/fc-layers/frozen.pb
3241
input_get: get_ramp

tf2onnx/loader.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,15 @@
88
from __future__ import print_function
99
from __future__ import unicode_literals
1010

11+
import logging
12+
1113
import tensorflow as tf
1214
from tensorflow.python.framework.graph_util import convert_variables_to_constants
1315

16+
from tf2onnx import utils
17+
18+
logging.basicConfig(level=logging.INFO)
19+
log = logging.getLogger("loader")
1420

1521
# pylint: disable=unused-argument
1622

@@ -87,6 +93,15 @@ def from_saved_model(model_path, input_names, output_names):
8793
for _, output_tensor in sorted(outputs_tensor_info.items()):
8894
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
8995
frozen_graph = freeze_session(sess, output_names=list(outputs.keys()))
96+
frozen_inputs = []
97+
# get inputs in frozen graph
98+
for n in frozen_graph.node:
99+
for inp, _ in inputs.items():
100+
if utils.node_name(inp) == n.name:
101+
frozen_inputs.append(inp)
102+
deleted_inputs = list(set(inputs.keys()) - set(frozen_inputs))
103+
if deleted_inputs:
104+
log.warning("inputs [%s] is not in frozen graph, delete them", ",".join(deleted_inputs))
90105
# clean up
91106
tf.reset_default_graph()
92-
return frozen_graph, inputs.keys(), outputs.keys()
107+
return frozen_graph, frozen_inputs, outputs.keys()

0 commit comments

Comments
 (0)