Skip to content

Commit 87f5b0a

Browse files
authored
Merge pull request #354 from onnx/gs/more_tf_formats
support for checkpoint and saved_model format
2 parents 081e295 + 03e752a commit 87f5b0a

16 files changed

+283
-91
lines changed

README.md

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ onnxruntime (only avaliable on linux):
3333

3434
```pip install onnxruntime```
3535

36-
For caffe2, follow the instructions here:
36+
For pytorch/caffe2, follow the instructions here:
3737

38-
```https://caffe2.ai/```
38+
```https://pytorch.org/```
3939

4040

41-
We tested with caffe2 and onnxruntime and unit tests are passing for those.
41+
We tested with pytorch/caffe2 and onnxruntime and unit tests are passing for those.
4242

4343
## Supported Tensorflow and Python Versions
44-
We tested with tensorflow 1.5-1.11 and anaconda **3.5,3.6**.
44+
We tested with tensorflow 1.5-1.12 and anaconda **3.5,3.6**.
4545

4646
# Installation
4747
## From Pypi
@@ -64,13 +64,17 @@ python setup.py bdist_wheel
6464

6565
# Usage
6666

67-
To convert a TensorFlow model, tf2onnx expects a ```frozen TensorFlow graph``` and the user needs to specify inputs and outputs for the graph by passing the input and output
67+
To convert a TensorFlow model, tf2onnx prefers a ```frozen TensorFlow graph``` and the user needs to specify inputs and outputs for the graph by passing the input and output
6868
names with ```--inputs INPUTS``` and ```--outputs OUTPUTS```.
6969

7070
```
71-
python -m tf2onnx.convert --input SOURCE_FROZEN_GRAPH_PB
72-
--inputs SOURCE_GRAPH_INPUTS
73-
--outputs SOURCE_GRAPH_OUTPUS
71+
python -m tf2onnx.convert
72+
--input SOURCE_GRAPHDEF_PB
73+
--graphdef SOURCE_GRAPHDEF_PB
74+
--checkpoint SOURCE_CHECKPOINT
75+
--saved-model SOURCE_SAVED_MODEL
76+
[--inputs GRAPH_INPUTS]
77+
[--outputs GRAPH_OUTPUS]
7478
[--inputs-as-nchw inputs_provided_as_nchw]
7579
[--target TARGET]
7680
[--output TARGET_ONNX_GRAPH]
@@ -83,21 +87,26 @@ python -m tf2onnx.convert --input SOURCE_FROZEN_GRAPH_PB
8387
```
8488

8589
## Parameters
86-
### input
87-
frozen TensorFlow graph, which can be created with the [freeze graph tool](#freeze_graph).
88-
### output
90+
### --input or --graphdef
91+
TensorFlow model as graphdef file. If not already frozen we'll try to freeze the model.
92+
More information about freezing can be found here: [freeze graph tool](#freeze_graph).
93+
### --checkpoint
94+
TensorFlow model as checkpoint. We expect the path to the .meta file. tf2onnx will try to freeze the graph.
95+
### --saved-model
96+
TensorFlow model as saved_model. We expect the path to the saved_model directory. tf2onnx will try to freeze the graph.
97+
### --output
8998
the target onnx file path.
90-
### inputs, outputs
91-
Tensorflow graph's input/output names, which can be found with [summarize graph tool](#summarize_graph). Those names typically end on ```:0```, for example ```--inputs input0:0,input1:0```
92-
### inputs-as-nchw
99+
### --inputs, --outputs
100+
Tensorflow model's input/output names, which can be found with [summarize graph tool](#summarize_graph). Those names typically end on ```:0```, for example ```--inputs input0:0,input1:0```. inputs and outputs are ***not*** needed for models in saved-model format.
101+
### --inputs-as-nchw
93102
By default we preserve the image format of inputs (nchw or nhwc) as given in the TensorFlow model. If your hosts (for example windows) native format nchw and the model is written for nhwc, ```--inputs-as-nchw``` tensorflow-onnx will transpose the input. Doing so is convinient for the application and the converter in many cases can optimize the transpose away. For example ```--inputs input0:0,input1:0 --inputs-as-nchw input0:0``` assumes that images are passed into ```input0:0``` as nchw while the TensorFlow model given uses nhwc.
94-
### target
103+
### --target
95104
Some runtimes need workarounds, for example they don't support all types given in the onnx spec. We'll workaround it in some cases by generating a different graph. Those workarounds are activated with ```--target TARGET```.
96-
### opset
105+
### --opset
97106
by default we uses the newest opset 7 to generate the graph. By specifieing ```--opset``` the user can override the default to generate a graph with the desired opset. For example ```--opset 5``` would create a onnx graph that uses only ops available in opset 5. Because older opsets have in most cases fewer ops, some models might not convert on a older opset.
98-
### custom-ops
107+
### --custom-ops
99108
the runtime may support custom ops that are not defined in onnx. A user can asked the converter to map to custom ops by listing them with the --custom-ops option. Tensorflow ops listed here will be mapped to a custom op with the same name as the tensorflow op but in the onnx domain ai.onnx.converters.tensorflow. For example: ```--custom-ops Print``` will insert a op ```Print``` in the onnx domain ```ai.onnx.converters.tensorflow``` into the graph. We also support a python api for custom ops documented later in this readme.
100-
### fold_const
109+
### --fold_const
101110
when set, TensorFlow fold_constants transformation will be applied before conversion. This will benefit features including Transpose optimization (e.g. Transpose operations introduced during tf-graph-to-onnx-graph conversion will be removed), and RNN unit conversion (for example LSTM). Older TensorFlow version might run into issues with this option depending on the model.
102111

103112

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
model_checkpoint_path: "model"
2+
all_model_checkpoint_paths: "model"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
�[>�V�?
134 Bytes
Binary file not shown.
19.1 KB
Binary file not shown.
350 Bytes
Binary file not shown.
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
�[>�V�?
Binary file not shown.

tests/run_pretrained_models.py

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,18 @@
1616
import traceback
1717
import zipfile
1818

19+
import PIL.Image
1920
import numpy as np
2021
import requests
2122
import six
2223
import tensorflow as tf
23-
from tensorflow.core.framework import graph_pb2
24-
from tensorflow.python.framework.graph_util import convert_variables_to_constants
2524
# contrib ops are registered only when the module is imported, the following import statement is needed,
2625
# otherwise tf runtime error will show up when the tf model is restored from pb file because of un-registered ops.
2726
import tensorflow.contrib.rnn # pylint: disable=unused-import
2827
import yaml
29-
import PIL.Image
3028

3129
import tf2onnx
30+
from tf2onnx import loader
3231
from tf2onnx import utils
3332
from tf2onnx.graph import GraphUtil
3433
from tf2onnx.tfonnx import process_tf_graph
@@ -74,23 +73,6 @@ def get_ramp(shape):
7473
}
7574

7675

77-
def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
78-
"""Freezes the state of a session into a pruned computation graph."""
79-
output_names = [i.replace(":0", "") for i in output_names]
80-
graph = sess.graph
81-
with graph.as_default():
82-
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
83-
output_names = output_names or []
84-
output_names += [v.op.name for v in tf.global_variables()]
85-
input_graph_def = graph.as_graph_def()
86-
if clear_devices:
87-
for node in input_graph_def.node:
88-
node.device = ""
89-
frozen_graph = convert_variables_to_constants(sess, input_graph_def,
90-
output_names, freeze_var_names)
91-
return frozen_graph
92-
93-
9476
class Test(object):
9577
"""Main Test class."""
9678

@@ -236,44 +218,14 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
236218
dir_name = os.path.dirname(self.local)
237219
print("\tdownloaded", model_path)
238220

221+
inputs = list(self.input_names.keys())
222+
outputs = self.output_names
239223
if self.model_type in ["checkpoint"]:
240-
#
241-
# if the input model is a checkpoint, convert it to a frozen model
242-
saver = tf.train.import_meta_graph(model_path)
243-
with tf.Session() as sess:
244-
saver.restore(sess, model_path[:-5])
245-
frozen_graph = freeze_session(sess, output_names=self.output_names)
246-
tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
247-
model_path = os.path.join(dir_name, "frozen.pb")
224+
graph_def, inputs, outputs = loader.from_checkpoint(model_path, inputs, outputs)
248225
elif self.model_type in ["saved_model"]:
249-
try:
250-
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
251-
get_signature_def = lambda meta_graph_def, k: \
252-
signature_def_utils.get_signature_def_by_key(meta_graph_def, k)
253-
except ImportError:
254-
# TF1.12 changed the api
255-
get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]
256-
257-
# saved_model format - convert to checkpoint
258-
with tf.Session() as sess:
259-
meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
260-
inputs = {}
261-
outputs = {}
262-
for k in meta_graph_def.signature_def.keys():
263-
inputs_tensor_info = get_signature_def(meta_graph_def, k).inputs
264-
for _, input_tensor in sorted(inputs_tensor_info.items()):
265-
inputs[input_tensor.name] = sess.graph.get_tensor_by_name(input_tensor.name)
266-
outputs_tensor_info = get_signature_def(meta_graph_def, k).outputs
267-
for _, output_tensor in sorted(outputs_tensor_info.items()):
268-
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
269-
# freeze uses the node name derived from output:0 so only pass in output:0;
270-
# it will provide all outputs of that node.
271-
for o in list(outputs.keys()):
272-
if not o.endswith(":0"):
273-
del outputs[o]
274-
frozen_graph = freeze_session(sess, output_names=list(outputs.keys()))
275-
tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
276-
model_path = os.path.join(dir_name, "frozen.pb")
226+
graph_def, inputs, outputs = loader.from_saved_model(model_path, inputs, outputs)
227+
else:
228+
graph_def, inputs, outputs = loader.from_graphdef(model_path, inputs, outputs)
277229

278230
# create the input data
279231
inputs = {}
@@ -285,10 +237,6 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
285237
if self.more_inputs:
286238
for k, v in self.more_inputs.items():
287239
inputs[k] = v
288-
tf.reset_default_graph()
289-
graph_def = graph_pb2.GraphDef()
290-
with open(model_path, "rb") as f:
291-
graph_def.ParseFromString(f.read())
292240

293241
graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
294242
shape_override = {}

0 commit comments

Comments
 (0)