Skip to content

Commit 24b39f2

Browse files
add an utility to save pretrained model
1 parent 6d3f24b commit 24b39f2

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,9 @@ python tests/run_pretrained_models.py --backend onnxruntime --config tests/run_p
179179
### <a name="save_pretrained_model"></a>Tool to save pre-trained model
180180

181181
We provide an [utility](tools/save_pretrained_model.py) to save pre-trained model along with its config.
182-
Put `save_pretrained_model(sess, outputs, feed_inputs, save_dir, model_name)` in your last testing step and the pre-trained model and config will be saved under `save_dir/to_onnx`.
183-
Please refer to [tools/save_pretrained_model.py](tools/save_pretrained_model.py) for more information.
182+
Put `save_pretrained_model(sess, outputs, feed_inputs, save_dir, model_name)` in your last testing epoch and the pre-trained model and config will be saved under `save_dir/to_onnx`.
183+
Please refer to the example in [tools/save_pretrained_model.py](tools/save_pretrained_model.py) for more information.
184+
Note the minimum required Tensorflow version is r1.6.
184185

185186
# Using the Python API
186187
## TensorFlow to ONNX conversion

tools/save_pretrained_model.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,46 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Save pre-trained model.
6+
"""
17
import tensorflow as tf
28
import numpy as np
39

10+
# pylint: disable=redefined-outer-name,reimported
11+
412
def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"):
13+
"""Save pretrained model and config"""
514
try:
615
import os
16+
import sys
17+
import tensorflow as tf
18+
import subprocess
719
to_onnx_path = "{}/to_onnx".format(out_dir)
820
if not os.path.isdir(to_onnx_path):
921
os.makedirs(to_onnx_path)
1022
saved_model = "{}/saved_model".format(to_onnx_path)
1123
inputs_path = "{}/inputs.npy".format(to_onnx_path)
1224
pretrained_model_yaml_path = "{}/pretrained.yaml".format(to_onnx_path)
25+
envars_path = "{}/environment.txt".format(to_onnx_path)
26+
pip_requirement_path = "{}/requirements.txt".format(to_onnx_path)
27+
28+
print("===============Save Saved Model========================")
29+
if os.path.exists(saved_model):
30+
print("{} already exists, SKIP".format(saved_model))
31+
return
32+
33+
print("Save tf version, python version and installed packages")
34+
tf_version = tf.__version__
35+
py_version = sys.version
36+
pip_packages = subprocess.check_output([sys.executable, "-m", "pip", "freeze", "--all"])
37+
pip_packages = pip_packages.decode("UTF-8")
38+
with open(envars_path, "w") as fp:
39+
fp.write(tf_version + os.linesep)
40+
fp.write(py_version)
41+
with open(pip_requirement_path, "w") as fp:
42+
fp.write(pip_packages)
1343

14-
print("===============Save Frozen Graph========================")
1544
print("Save model for tf2onnx: {}".format(to_onnx_path))
1645
# save inputs
1746
inputs = {}
@@ -26,7 +55,7 @@ def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"
2655
# save graph and weights
2756
from tensorflow.saved_model import simple_save
2857
simple_save(sess, saved_model,
29-
{n: i for n,i in zip(inputs.keys(), feeds.keys())},
58+
{n: i for n, i in zip(inputs.keys(), feeds.keys())},
3059
{op.name: op for op in outputs})
3160
print("Saved model to {}".format(saved_model))
3261

@@ -39,9 +68,8 @@ def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"
3968
'''.format(model_name)
4069
pretrained_model_yaml += " inputs:\n"
4170
for inp, _ in inputs.items():
42-
pretrained_model_yaml += " \"{}\": np.array(np.load(\"./inputs.npy\")[()][\"{}\"])\n".format(
43-
inp, inp
44-
)
71+
pretrained_model_yaml += \
72+
" \"{input}\": np.array(np.load(\"./inputs.npy\")[()][\"{input}\"])\n".format(input=inp)
4573
outputs = [op.name for op in outputs]
4674
pretrained_model_yaml += " outputs:\n"
4775
for out in outputs:
@@ -50,11 +78,12 @@ def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"
5078
f.write(pretrained_model_yaml)
5179
print("Saved pretrained model yaml to {}".format(pretrained_model_yaml_path))
5280
print("=========================================================")
53-
except Exception as ex:
81+
except Exception as ex: # pylint: disable=broad-except
5482
print("Error: {}".format(ex))
5583

5684

5785
def test():
86+
"""Test sample."""
5887
x_val = np.random.rand(5, 20).astype(np.float32)
5988
y_val = np.random.rand(20, 10).astype(np.float32)
6089
x = tf.placeholder(tf.float32, x_val.shape, name="x")
@@ -66,8 +95,9 @@ def test():
6695
feeds = {x: x_val, y: y_val}
6796
with tf.Session() as sess:
6897
sess.run(init)
69-
out = sess.run(outputs, feeds)
70-
# NOTE: Put below snippet after the LAST testing step
98+
sess.run(outputs, feeds)
99+
# NOTE: NOT override the saved model, so put below snippet after testing the BEST model.
100+
# if you perform testing several times.
71101
save_pretrained_model(sess, outputs, feeds, "./tests", model_name="test")
72102

73103

0 commit comments

Comments
 (0)