Skip to content

Commit 443e154

Browse files
authored
Merge pull request #454 from lucienwang1009/save_tf
add an utility to save pretrained model
2 parents 3e490a8 + 24b39f2 commit 443e154

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,13 @@ You call it for example with:
176176
python tests/run_pretrained_models.py --backend onnxruntime --config tests/run_pretrained_models.yaml --perf perf.csv
177177
```
178178

179+
### <a name="save_pretrained_model"></a>Tool to save pre-trained model
180+
181+
We provide an [utility](tools/save_pretrained_model.py) to save pre-trained model along with its config.
182+
Put `save_pretrained_model(sess, outputs, feed_inputs, save_dir, model_name)` in your last testing epoch and the pre-trained model and config will be saved under `save_dir/to_onnx`.
183+
Please refer to the example in [tools/save_pretrained_model.py](tools/save_pretrained_model.py) for more information.
184+
Note the minimum required Tensorflow version is r1.6.
185+
179186
# Using the Python API
180187
## TensorFlow to ONNX conversion
181188
In some cases it will be useful to convert the models from TensorFlow to ONNX from a python script. You can use the following API:

tools/save_pretrained_model.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Save pre-trained model.
6+
"""
7+
import tensorflow as tf
8+
import numpy as np
9+
10+
# pylint: disable=redefined-outer-name,reimported
11+
12+
def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"):
13+
"""Save pretrained model and config"""
14+
try:
15+
import os
16+
import sys
17+
import tensorflow as tf
18+
import subprocess
19+
to_onnx_path = "{}/to_onnx".format(out_dir)
20+
if not os.path.isdir(to_onnx_path):
21+
os.makedirs(to_onnx_path)
22+
saved_model = "{}/saved_model".format(to_onnx_path)
23+
inputs_path = "{}/inputs.npy".format(to_onnx_path)
24+
pretrained_model_yaml_path = "{}/pretrained.yaml".format(to_onnx_path)
25+
envars_path = "{}/environment.txt".format(to_onnx_path)
26+
pip_requirement_path = "{}/requirements.txt".format(to_onnx_path)
27+
28+
print("===============Save Saved Model========================")
29+
if os.path.exists(saved_model):
30+
print("{} already exists, SKIP".format(saved_model))
31+
return
32+
33+
print("Save tf version, python version and installed packages")
34+
tf_version = tf.__version__
35+
py_version = sys.version
36+
pip_packages = subprocess.check_output([sys.executable, "-m", "pip", "freeze", "--all"])
37+
pip_packages = pip_packages.decode("UTF-8")
38+
with open(envars_path, "w") as fp:
39+
fp.write(tf_version + os.linesep)
40+
fp.write(py_version)
41+
with open(pip_requirement_path, "w") as fp:
42+
fp.write(pip_packages)
43+
44+
print("Save model for tf2onnx: {}".format(to_onnx_path))
45+
# save inputs
46+
inputs = {}
47+
for inp, value in feeds.items():
48+
if isinstance(inp, str):
49+
inputs[inp] = value
50+
else:
51+
inputs[inp.name] = value
52+
np.save(inputs_path, inputs)
53+
print("Saved inputs to {}".format(inputs_path))
54+
55+
# save graph and weights
56+
from tensorflow.saved_model import simple_save
57+
simple_save(sess, saved_model,
58+
{n: i for n, i in zip(inputs.keys(), feeds.keys())},
59+
{op.name: op for op in outputs})
60+
print("Saved model to {}".format(saved_model))
61+
62+
# generate config
63+
pretrained_model_yaml = '''
64+
{}:
65+
model: ./saved_model
66+
model_type: saved_model
67+
input_get: get_ramp
68+
'''.format(model_name)
69+
pretrained_model_yaml += " inputs:\n"
70+
for inp, _ in inputs.items():
71+
pretrained_model_yaml += \
72+
" \"{input}\": np.array(np.load(\"./inputs.npy\")[()][\"{input}\"])\n".format(input=inp)
73+
outputs = [op.name for op in outputs]
74+
pretrained_model_yaml += " outputs:\n"
75+
for out in outputs:
76+
pretrained_model_yaml += " - {}\n".format(out)
77+
with open(pretrained_model_yaml_path, "w") as f:
78+
f.write(pretrained_model_yaml)
79+
print("Saved pretrained model yaml to {}".format(pretrained_model_yaml_path))
80+
print("=========================================================")
81+
except Exception as ex: # pylint: disable=broad-except
82+
print("Error: {}".format(ex))
83+
84+
85+
def test():
86+
"""Test sample."""
87+
x_val = np.random.rand(5, 20).astype(np.float32)
88+
y_val = np.random.rand(20, 10).astype(np.float32)
89+
x = tf.placeholder(tf.float32, x_val.shape, name="x")
90+
y = tf.placeholder(tf.float32, y_val.shape, name="y")
91+
z = tf.matmul(x, y)
92+
w = tf.get_variable("weight", [5, 10], dtype=tf.float32)
93+
init = tf.global_variables_initializer()
94+
outputs = [z + w]
95+
feeds = {x: x_val, y: y_val}
96+
with tf.Session() as sess:
97+
sess.run(init)
98+
sess.run(outputs, feeds)
99+
# NOTE: NOT override the saved model, so put below snippet after testing the BEST model.
100+
# if you perform testing several times.
101+
save_pretrained_model(sess, outputs, feeds, "./tests", model_name="test")
102+
103+
104+
if __name__ == "__main__":
105+
test()

0 commit comments

Comments
 (0)