Skip to content

Commit 4a7e895

Browse files
committed
document python api
1 parent 1e9e2cc commit 4a7e895

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

README.md

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,6 @@ python -m tensorflow.python.tools.freeze_graph \
100100
--output_graph=tests/models/fc-layers/frozen.pb
101101
```
102102

103-
# Using the converter via Python Api
104-
In some cases it might be desirable to use the converter from a python script.
105-
106103
# Testing
107104
There are 2 types of tests.
108105

@@ -173,10 +170,39 @@ with tf.Session() as sess:
173170
with open("/tmp/model.onnx", "wb") as f:
174171
f.write(model_proto.SerializeToString())
175172
```
176-
## Using custom ops from python
173+
## Creating custom op mappings from python
177174
For complex custom ops that require graph rewrites or input / attribute rewrites using the python interface to insert a custom op will be the eaiest way to accomplish the task.
178175
A dictionary of name->custom_op_handler can be passed to tf2onnx.tfonnx.process_tf_graph. If the op name is found in the graph the handler will have access to all internal structures and can rewrite that is needed. For example [examples/custom_op_via_python.py]():
179176
```
177+
import tensorflow as tf
178+
import tf2onnx
179+
from onnx import helper
180+
181+
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
182+
183+
184+
def print_handler(ctx, node, name, args):
185+
# replace tf.Print() with Identity
186+
# T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
187+
# becomes:
188+
# T output = Identity(T Input)
189+
node.type = "Identity"
190+
node.domain = _TENSORFLOW_DOMAIN
191+
del node.input[1:]
192+
return node
193+
194+
195+
with tf.Session() as sess:
196+
x = tf.placeholder(tf.float32, [2, 3], name="input")
197+
x_ = tf.add(x, x)
198+
x_ = tf.Print(x, [x], "hello")
199+
_ = tf.identity(x_, name="output")
200+
onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
201+
custom_op_handlers={"Print": print_handler},
202+
extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)])
203+
model_proto = onnx_graph.make_model("test", ["input:0"], ["output:0"])
204+
with open("/tmp/model.onnx", "wb") as f:
205+
f.write(model_proto.SerializeToString())
180206
```
181207

182208
# How tf2onnx works

0 commit comments

Comments
 (0)