Skip to content

Commit 229985e

Browse files
Add getting started sample code (#1502)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent f1617ac commit 229985e

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

examples/getting_started.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
"""
4+
This example shows how to convert tf functions and keras models using the Python API.
5+
It also demonstrates converting saved_models from the command line.
6+
"""
7+
8+
import tensorflow as tf
9+
import tf2onnx
10+
import numpy as np
11+
import onnxruntime as ort
12+
import os
13+
14+
##################### tf function #####################
15+
16+
@tf.function
17+
def f(a, b):
18+
return a + b
19+
20+
input_signature = [tf.TensorSpec([2, 3], tf.float32), tf.TensorSpec([2, 3], tf.float32)]
21+
onnx_model, _ = tf2onnx.convert.from_function(f, input_signature, opset=13)
22+
23+
a_val = np.ones([2, 3], np.float32)
24+
b_val = np.zeros([2, 3], np.float32)
25+
26+
print("Tensorflow result")
27+
print(f(a_val, b_val).numpy())
28+
29+
print("ORT result")
30+
sess = ort.InferenceSession(onnx_model.SerializeToString())
31+
res = sess.run(None, {'a': a_val, 'b': b_val})
32+
print(res[0])
33+
34+
35+
##################### Keras Model #####################
36+
37+
model = tf.keras.Sequential()
38+
model.add(tf.keras.layers.Dense(4, activation="relu"))
39+
40+
input_signature = [tf.TensorSpec([3, 3], tf.float32, name='x')]
41+
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13)
42+
43+
x_val = np.ones((3, 3), np.float32)
44+
45+
print("Keras result")
46+
print(model(x_val).numpy())
47+
48+
print("ORT result")
49+
sess = ort.InferenceSession(onnx_model.SerializeToString())
50+
res = sess.run(None, {'x': x_val})
51+
print(res[0])
52+
53+
54+
##################### Saved Model #####################
55+
56+
model.save("savedmodel")
57+
os.system("python -m tf2onnx.convert --saved-model savedmodel --output model.onnx --opset 13")
58+
59+
print("ORT result")
60+
sess = ort.InferenceSession("model.onnx")
61+
res = sess.run(None, {'dense_input:0': x_val})
62+
print(res[0])
63+
64+
print("Conversion succeeded")

tests/test_example.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def run_example(self, name, expected=None):
1919
"..", "examples", name)
2020
if not os.path.exists(full):
2121
raise FileNotFoundError(full)
22-
proc = subprocess.run(('python %s' % full).split(),
22+
proc = subprocess.run(['python', full],
2323
capture_output=True, check=True)
2424
self.assertEqual(0, proc.returncode)
2525
out = proc.stdout.decode('ascii')
@@ -51,6 +51,14 @@ def test_end2end_tfhub(self):
5151
"Optimizing ONNX model",
5252
"Using opset <onnx, 12>"])
5353

54+
@check_tf_min_version("2.3", "use tf.keras")
55+
@check_opset_min_version(13)
56+
@check_opset_max_version(13)
57+
def test_getting_started(self):
58+
self.run_example(
59+
"getting_started.py",
60+
expected=["Conversion succeeded"])
61+
5462

5563
if __name__ == '__main__':
5664
unittest.main()

0 commit comments

Comments
 (0)