Commit 7328183

Merge pull request #71 from onnx/gs/lstm
document the python api
2 parents 06339ac + 4a7e895 commit 7328183

File tree

4 files changed

+143
-17
lines changed


README.md

Lines changed: 85 additions & 8 deletions
@@ -58,15 +58,17 @@ python -m tf2onnx.convert --input SOURCE_FROZEN_GRAPH_PB\
 [--target TARGET]\
 [--continue_on_error]\
 [--verbose]\
+[--custom-ops list-of-custom-ops]\
 [--opset OPSET]
 ```
 
 Parameters:
-- input: frozen TensorFlow graph, which can be got with [freeze graph tool](#freeze_graph).
+- input: frozen TensorFlow graph, which can be created with the [freeze graph tool](#freeze_graph).
 - output: the target onnx file path.
-- inputs/outputs: Tensorflow graph's input/output names, which can be got with [summarize graph tool](#summarize_graph).
-- target: There are different onnx versions and workarounds for runtimes that can be set with ```--target TARGET```. The default is onnx-1.1 and caffe2 which generates a graph
-that can be executed on a onnx-1.0/onnx-1.1 runtime, like caffe2 and winml.
+- inputs/outputs: TensorFlow graph's input/output names, which can be found with the [summarize graph tool](#summarize_graph).
+- target: there are different onnx versions and workarounds for runtimes; these can be selected with ```--target TARGET```.
+- opset: by default we use the newest opset installed with the onnx package (for example, onnx-1.2.2 would have opset 7). By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example, ```--opset 5``` creates an onnx graph that uses only ops available in opset 5. Because older opsets in most cases have fewer ops, some models might not convert with an older opset.
+- custom-ops: the runtime may support custom ops that are not defined in onnx. A user can ask the converter to map to custom ops by listing them with the --custom-ops option. TensorFlow ops listed here will be mapped to a custom op with the same name as the TensorFlow op, in the onnx domain ai.onnx.converters.tensorflow. For example, ```--custom-ops Print``` inserts an op ```Print``` in the onnx domain ```ai.onnx.converters.tensorflow``` into the graph. We also support a python api for custom ops, documented later in this readme.
 
 Usage example (run the following commands in the tensorflow-onnx root directory):
 ```
@@ -78,7 +80,7 @@ python -m tf2onnx.convert\
 --verbose
 ```
 
-## <a name="summarize_graph"></a>Tool to Get Graph Inputs & Outputs
+## <a name="summarize_graph"></a>Tool to get Graph Inputs & Outputs
 
 The model developer will often know the inputs and outputs of the TensorFlow graph; if not, you can consult TensorFlow's [summarize_graph](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) tool, for example:
 ```
@@ -98,7 +100,6 @@ python -m tensorflow.python.tools.freeze_graph \
 --output_graph=tests/models/fc-layers/frozen.pb
 ```
 
-
 # Testing
 There are 2 types of tests.
 
@@ -120,14 +121,90 @@ optional arguments:
 --config yaml config file
 --verbose verbose output
 --opset OPSET target opset to use
+--perf csv-file capture performance numbers for tensorflow and onnx runtime
 --debug dump generated graph with shape info
 ```
-```run_pretrained_models.py``` will run the TensorFlow model, captures the TensorFlow output and runs the same test against the specified ONNX backend after converting the model. The only practical backend to use at this time is Caffe2, and you need to install Caffe2 for this to work.
+```run_pretrained_models.py``` runs the TensorFlow model, captures the TensorFlow output, and runs the same test against the specified ONNX backend after converting the model. The only practical backend to use at this time is Caffe2, and you need to install Caffe2 for this to work.
+If the option ```--perf csv-file``` is specified, we'll capture the eval runtime for tensorflow and the onnx runtime and write the result into the given csv file.
 
 You call it for example with:
 ```
-python tests/run_pretrained_models.py --backend caffe2 --config tests/run_pretrained_models.yaml
+python tests/run_pretrained_models.py --backend caffe2 --config tests/run_pretrained_models.yaml --perf perf.csv
+```
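For illustration only, the kind of capture ```--perf``` performs can be sketched in plain python. The column names and the timed stand-in function below are assumptions for the sketch, not the tool's actual schema:

```python
import csv
import io
import time

def time_avg(fn, runs=10):
    # average wall-clock seconds over `runs` calls, as a perf harness might do
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs

# stand-in for one model evaluation (hypothetical workload)
dummy_model = lambda: sum(i * i for i in range(1000))

buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["model", "tensorflow_sec", "onnx_sec"])  # assumed columns
writer.writerow(["example", time_avg(dummy_model), time_avg(dummy_model)])
print(buf.getvalue().splitlines()[0])  # prints the header row
```

In the real tool, ```buf``` would be the csv file named on the command line and the two timed calls would be the TensorFlow and ONNX backend evaluations.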
+# Using the Python API
+## Tensorflow to onnx conversion
+In some cases it will be useful to convert models from tensorflow to onnx from a python script. You can use the following api:
+```
+import tf2onnx
+
+tf2onnx.tfonnx.process_tf_graph(tf_graph,
+        continue_on_error=False, verbose=False, target=None,
+        opset=None, custom_op_handlers=None,
+        custom_rewriter=None, extra_opset=None):
+    """Convert tensorflow graph to onnx graph.
+        Args:
+            tf_graph: tensorflow graph
+            continue_on_error: if an op can't be processed (aka there is no mapping), continue
+            verbose: print summary stats
+            target: list of workarounds applied to help certain platforms
+            opset: the opset to be used (int, default is latest)
+            custom_op_handlers: dictionary of custom op handlers
+            custom_rewriter: list of custom graph rewriters
+        Return:
+            onnx graph
+    """
+```
+For example in [examples/call_coverter_via_python.py]():
+```
+import tensorflow as tf
+import tf2onnx
+
+with tf.Session() as sess:
+    x = tf.placeholder(tf.float32, [2, 3], name="input")
+    x_ = tf.add(x, x)
+    _ = tf.identity(x_, name="output")
+    onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph)
+    model_proto = onnx_graph.make_model("test",
+                                        ["input:0"], ["output:0"])
+    with open("/tmp/model.onnx", "wb") as f:
+        f.write(model_proto.SerializeToString())
+```
+## Creating custom op mappings from python
+For complex custom ops that require graph rewrites or input / attribute rewrites, using the python interface to insert a custom op will be the easiest way to accomplish the task.
+A dictionary of name->custom_op_handler can be passed to tf2onnx.tfonnx.process_tf_graph. If the op name is found in the graph, the handler will have access to all internal structures and can rewrite whatever is needed. For example [examples/custom_op_via_python.py]():
+```
+import tensorflow as tf
+import tf2onnx
+from onnx import helper
+
+_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
+
+
+def print_handler(ctx, node, name, args):
+    # replace tf.Print() with Identity
+    #   T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
+    # becomes:
+    #   T output = Identity(T input)
+    node.type = "Identity"
+    node.domain = _TENSORFLOW_DOMAIN
+    del node.input[1:]
+    return node
+
+
+with tf.Session() as sess:
+    x = tf.placeholder(tf.float32, [2, 3], name="input")
+    x_ = tf.add(x, x)
+    x_ = tf.Print(x_, [x_], "hello")
+    _ = tf.identity(x_, name="output")
+    onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
+        custom_op_handlers={"Print": print_handler},
+        extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)])
+    model_proto = onnx_graph.make_model("test", ["input:0"], ["output:0"])
+    with open("/tmp/model.onnx", "wb") as f:
+        f.write(model_proto.SerializeToString())
 ```
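The rewrite that ```print_handler``` performs can be exercised without TensorFlow by calling it on a plain stand-in object. The ```FakeNode``` class below is purely illustrative, not tf2onnx's real node wrapper:

```python
class FakeNode:
    """Minimal stand-in for tf2onnx's internal node wrapper (illustration only)."""
    def __init__(self, type_, inputs):
        self.type = type_
        self.domain = ""
        self.input = list(inputs)

def print_handler(ctx, node, name, args):
    # replace a Print node with Identity in the custom domain:
    # keep only the first input, drop data/message/etc.
    node.type = "Identity"
    node.domain = "ai.onnx.converters.tensorflow"
    del node.input[1:]
    return node

node = FakeNode("Print", ["input:0", "data:0", "message:0"])
print_handler(None, node, "Print", [])
print(node.type, node.input)  # prints: Identity ['input:0']
```

This mirrors the contract in the example above: the handler mutates the node in place and returns it, and the converter records the op as mapped.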
+
 # How tf2onnx works
 While the protobuf format of ONNX is not all that different from TensorFlow's, mileage will vary because TensorFlow supports roughly 4x the ops of the current version of ONNX.
 The converter needs to take care of a few things:

examples/call_coverter_via_python.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+"""
+A simple example of how to call tensorflow-onnx via python.
+"""
+
+import tensorflow as tf
+import tf2onnx
+
+with tf.Session() as sess:
+    x = tf.placeholder(tf.float32, [2, 3], name="input")
+    x_ = tf.add(x, x)
+    _ = tf.identity(x_, name="output")
+    onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph)
+    model_proto = onnx_graph.make_model("test", ["input:0"], ["output:0"])
+    with open("/tmp/model.onnx", "wb") as f:
+        f.write(model_proto.SerializeToString())

examples/custom_op_via_python.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+"""
+A simple example of how to map a custom op in python.
+"""
+import tensorflow as tf
+import tf2onnx
+from onnx import helper
+
+_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
+
+
+def print_handler(ctx, node, name, args):
+    # replace tf.Print() with Identity
+    #   T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
+    # becomes:
+    #   T output = Identity(T input)
+    node.type = "Identity"
+    node.domain = _TENSORFLOW_DOMAIN
+    del node.input[1:]
+    return node
+
+
+with tf.Session() as sess:
+    x = tf.placeholder(tf.float32, [2, 3], name="input")
+    x_ = tf.add(x, x)
+    x_ = tf.Print(x_, [x_], "hello")
+    _ = tf.identity(x_, name="output")
+    onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
+        custom_op_handlers={"Print": print_handler},
+        extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)])
+    model_proto = onnx_graph.make_model("test", ["input:0"], ["output:0"])
+    with open("/tmp/model.onnx", "wb") as f:
+        f.write(model_proto.SerializeToString())

tf2onnx/tfonnx.py

Lines changed: 11 additions & 9 deletions
@@ -1279,22 +1279,24 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
     if target_opset <= g.opset:
         ops_mapping.update(op_map)
 
+    # apply custom ops on top of the assembled opset. We can either complement the opset
+    # or override existing ops with a custom op.
+    if custom_op_handlers is not None:
+        custom_opset = {k: [v, []] for k, v in custom_op_handlers.items()}
+        ops_mapping.update(custom_opset)
+
     ops = g.get_nodes()
     onnx_nodes = []
     for node in ops:
         op = node.type
         map_info = ops_mapping.get(op)
         if map_info is None:
-            custom_op = custom_op_handlers.get(op)
-            if custom_op is None:
-                if continue_on_error:
-                    unmapped_op[op] += 1
-                    onnx_nodes.append(node)
-                    continue
-                else:
-                    raise ValueError("tensorflow op " + op + " is not supported")
+            if continue_on_error:
+                unmapped_op[op] += 1
+                onnx_nodes.append(node)
+                continue
             else:
-                map_info = (custom_op, [])
+                raise ValueError("tensorflow op " + op + " is not supported")
         mapped_op[op] += 1
         func, args = map_info
         if args:
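The precedence this change establishes (custom handlers win over built-in mappings, and can also add brand-new ops) falls out of a plain dict ```update```. A small sketch with stand-in handler functions, not the converter's real op table:

```python
def builtin_relu(ctx, node, name, args):
    # stand-in for a built-in opset mapping
    return "builtin"

def custom_relu(ctx, node, name, args):
    # stand-in for a user-supplied handler
    return "custom"

# built-in table as assembled per opset: op name -> [handler, extra args]
ops_mapping = {"Relu": [builtin_relu, []], "Add": [builtin_relu, []]}

# wrap user handlers the same way and update() the table, so a custom op
# can complement the opset or override an existing mapping
custom_op_handlers = {"Relu": custom_relu, "Print": custom_relu}
ops_mapping.update({k: [v, []] for k, v in custom_op_handlers.items()})

func, args = ops_mapping["Relu"]
print(func(None, None, "relu", args))  # prints: custom (override wins)
print("Print" in ops_mapping)          # prints: True (new op added)
```

Because the merge happens once before the node loop, the per-node lookup no longer needs a separate custom-op fallback path, which is exactly the simplification the diff makes.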
