Skip to content

Commit b420b28

Browse files
committed
review feedback
1 parent 9bc75ec commit b420b28

File tree

5 files changed

+8
-13
lines changed

5 files changed

+8
-13
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ For pytorch/caffe2, follow the instructions here:
4141
We tested with pytorch/caffe2 and onnxruntime and unit tests are passing for those.
4242

4343
## Supported Tensorflow and Python Versions
44-
We tested with tensorflow 1.5-1.13 and anaconda **3.5,3.6**.
44+
We tested with tensorflow 1.5-1.12 and anaconda **3.5,3.6**.
4545

4646
# Installation
4747
## From Pypi
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
model_checkpoint_path: "C:\\src\\tensorflow-onnx\\tests\\models\\regression\\checkpoint\\model"
2-
all_model_checkpoint_paths: "C:\\src\\tensorflow-onnx\\tests\\models\\regression\\checkpoint\\model"
1+
model_checkpoint_path: "model"
2+
all_model_checkpoint_paths: "model"

tf2onnx/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def main():
108108

109109
# todo: consider to enable const folding by default?
110110
graph_def = tf_optimize(args.inputs, args.outputs, graph_def, args.fold_const)
111-
tf.reset_default_graph()
111+
112112
with tf.Graph().as_default() as tf_graph:
113113
tf.import_graph_def(graph_def, name='')
114114
with tf.Session(graph=tf_graph):

tf2onnx/loader.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
1919
"""Freezes the state of a session into a pruned computation graph."""
20-
output_names = [i.replace(":0", "") for i in output_names]
20+
output_names = [i.split(':')[:-1][0] for i in output_names]
2121
graph = sess.graph
2222
with graph.as_default():
2323
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
@@ -42,7 +42,7 @@ def from_graphdef(model_path, input_names, output_names):
4242
graph_def.ParseFromString(f.read())
4343
tf.import_graph_def(graph_def, name='')
4444
frozen_graph = freeze_session(sess, output_names=output_names)
45-
# clean up after us
45+
# clean up
4646
tf.reset_default_graph()
4747
return frozen_graph, input_names, output_names
4848

@@ -57,7 +57,7 @@ def from_checkpoint(model_path, input_names, output_names):
5757
# restore from model_path minus the ".meta"
5858
saver.restore(sess, model_path[:-5])
5959
frozen_graph = freeze_session(sess, output_names=output_names)
60-
# clean up after us
60+
# clean up
6161
tf.reset_default_graph()
6262
return frozen_graph, input_names, output_names
6363

@@ -86,12 +86,7 @@ def from_saved_model(model_path, input_names, output_names):
8686
outputs_tensor_info = get_signature_def(meta_graph_def, k).outputs
8787
for _, output_tensor in sorted(outputs_tensor_info.items()):
8888
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
89-
# freeze uses the node name derived from output:0 so only pass in output:0;
90-
# it will provide all outputs of that node.
91-
for o in list(outputs.keys()):
92-
if not o.endswith(":0"):
93-
del outputs[o]
9489
frozen_graph = freeze_session(sess, output_names=list(outputs.keys()))
95-
# clean up after us
90+
# clean up
9691
tf.reset_default_graph()
9792
return frozen_graph, inputs.keys(), outputs.keys()
File renamed without changes.

0 commit comments

Comments
 (0)