|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +""" |
| 5 | +Save pre-trained model. |
| 6 | +""" |
| 7 | +import tensorflow as tf |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +# pylint: disable=redefined-outer-name,reimported |
| 11 | + |
| 12 | +def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"): |
| 13 | + """Save pretrained model and config""" |
| 14 | + try: |
| 15 | + import os |
| 16 | + import sys |
| 17 | + import tensorflow as tf |
| 18 | + import subprocess |
| 19 | + to_onnx_path = "{}/to_onnx".format(out_dir) |
| 20 | + if not os.path.isdir(to_onnx_path): |
| 21 | + os.makedirs(to_onnx_path) |
| 22 | + saved_model = "{}/saved_model".format(to_onnx_path) |
| 23 | + inputs_path = "{}/inputs.npy".format(to_onnx_path) |
| 24 | + pretrained_model_yaml_path = "{}/pretrained.yaml".format(to_onnx_path) |
| 25 | + envars_path = "{}/environment.txt".format(to_onnx_path) |
| 26 | + pip_requirement_path = "{}/requirements.txt".format(to_onnx_path) |
| 27 | + |
| 28 | + print("===============Save Saved Model========================") |
| 29 | + if os.path.exists(saved_model): |
| 30 | + print("{} already exists, SKIP".format(saved_model)) |
| 31 | + return |
| 32 | + |
| 33 | + print("Save tf version, python version and installed packages") |
| 34 | + tf_version = tf.__version__ |
| 35 | + py_version = sys.version |
| 36 | + pip_packages = subprocess.check_output([sys.executable, "-m", "pip", "freeze", "--all"]) |
| 37 | + pip_packages = pip_packages.decode("UTF-8") |
| 38 | + with open(envars_path, "w") as fp: |
| 39 | + fp.write(tf_version + os.linesep) |
| 40 | + fp.write(py_version) |
| 41 | + with open(pip_requirement_path, "w") as fp: |
| 42 | + fp.write(pip_packages) |
| 43 | + |
| 44 | + print("Save model for tf2onnx: {}".format(to_onnx_path)) |
| 45 | + # save inputs |
| 46 | + inputs = {} |
| 47 | + for inp, value in feeds.items(): |
| 48 | + if isinstance(inp, str): |
| 49 | + inputs[inp] = value |
| 50 | + else: |
| 51 | + inputs[inp.name] = value |
| 52 | + np.save(inputs_path, inputs) |
| 53 | + print("Saved inputs to {}".format(inputs_path)) |
| 54 | + |
| 55 | + # save graph and weights |
| 56 | + from tensorflow.saved_model import simple_save |
| 57 | + simple_save(sess, saved_model, |
| 58 | + {n: i for n, i in zip(inputs.keys(), feeds.keys())}, |
| 59 | + {op.name: op for op in outputs}) |
| 60 | + print("Saved model to {}".format(saved_model)) |
| 61 | + |
| 62 | + # generate config |
| 63 | + pretrained_model_yaml = ''' |
| 64 | +{}: |
| 65 | + model: ./saved_model |
| 66 | + model_type: saved_model |
| 67 | + input_get: get_ramp |
| 68 | +'''.format(model_name) |
| 69 | + pretrained_model_yaml += " inputs:\n" |
| 70 | + for inp, _ in inputs.items(): |
| 71 | + pretrained_model_yaml += \ |
| 72 | + " \"{input}\": np.array(np.load(\"./inputs.npy\")[()][\"{input}\"])\n".format(input=inp) |
| 73 | + outputs = [op.name for op in outputs] |
| 74 | + pretrained_model_yaml += " outputs:\n" |
| 75 | + for out in outputs: |
| 76 | + pretrained_model_yaml += " - {}\n".format(out) |
| 77 | + with open(pretrained_model_yaml_path, "w") as f: |
| 78 | + f.write(pretrained_model_yaml) |
| 79 | + print("Saved pretrained model yaml to {}".format(pretrained_model_yaml_path)) |
| 80 | + print("=========================================================") |
| 81 | + except Exception as ex: # pylint: disable=broad-except |
| 82 | + print("Error: {}".format(ex)) |
| 83 | + |
| 84 | + |
| 85 | +def test(): |
| 86 | + """Test sample.""" |
| 87 | + x_val = np.random.rand(5, 20).astype(np.float32) |
| 88 | + y_val = np.random.rand(20, 10).astype(np.float32) |
| 89 | + x = tf.placeholder(tf.float32, x_val.shape, name="x") |
| 90 | + y = tf.placeholder(tf.float32, y_val.shape, name="y") |
| 91 | + z = tf.matmul(x, y) |
| 92 | + w = tf.get_variable("weight", [5, 10], dtype=tf.float32) |
| 93 | + init = tf.global_variables_initializer() |
| 94 | + outputs = [z + w] |
| 95 | + feeds = {x: x_val, y: y_val} |
| 96 | + with tf.Session() as sess: |
| 97 | + sess.run(init) |
| 98 | + sess.run(outputs, feeds) |
| 99 | + # NOTE: NOT override the saved model, so put below snippet after testing the BEST model. |
| 100 | + # if you perform testing several times. |
| 101 | + save_pretrained_model(sess, outputs, feeds, "./tests", model_name="test") |
| 102 | + |
| 103 | + |
| 104 | +if __name__ == "__main__": |
| 105 | + test() |
0 commit comments