|
8 | 8 | from __future__ import print_function
|
9 | 9 | from __future__ import unicode_literals
|
10 | 10 |
|
| 11 | +import logging |
| 12 | + |
11 | 13 | import tensorflow as tf
|
12 | 14 | from tensorflow.python.framework.graph_util import convert_variables_to_constants
|
13 | 15 |
|
| 16 | +from tf2onnx import utils |
| 17 | + |
| 18 | +logging.basicConfig(level=logging.INFO) |
| 19 | +log = logging.getLogger("loader") |
14 | 20 |
|
15 | 21 | # pylint: disable=unused-argument
|
16 | 22 |
|
@@ -87,6 +93,15 @@ def from_saved_model(model_path, input_names, output_names):
|
87 | 93 | for _, output_tensor in sorted(outputs_tensor_info.items()):
|
88 | 94 | outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
|
89 | 95 | frozen_graph = freeze_session(sess, output_names=list(outputs.keys()))
|
| 96 | + frozen_inputs = [] |
| 97 | + # get inputs in frozen graph |
| 98 | + for n in frozen_graph.node: |
| 99 | + for inp, _ in inputs.items(): |
| 100 | + if utils.node_name(inp) == n.name: |
| 101 | + frozen_inputs.append(inp) |
| 102 | + deleted_inputs = list(set(inputs.keys()) - set(frozen_inputs)) |
| 103 | + if deleted_inputs: |
| 104 | + log.warning("inputs [%s] is not in frozen graph, delete them", ",".join(deleted_inputs)) |
90 | 105 | # clean up
|
91 | 106 | tf.reset_default_graph()
|
92 |
| - return frozen_graph, inputs.keys(), outputs.keys() |
| 107 | + return frozen_graph, frozen_inputs, outputs.keys() |
0 commit comments