Skip to content

Commit b1eccce

Browse files
authored
Merge pull request #1103 from xadupre/example
Add an end2end example with tf.keras
2 parents 51767eb + 1bcb143 commit b1eccce

File tree

3 files changed

+201
-0
lines changed

3 files changed

+201
-0
lines changed

examples/end2end_tfhub.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
This example retrieves a model from tensorflowhub.
3+
It is converted into ONNX. Predictions are compared to
4+
the predictions from tensorflow to check there is no
5+
discrepencies. Inferencing time is also compared between
6+
*onnxruntime*, *tensorflow* and *tensorflow.lite*.
7+
"""
8+
from onnxruntime import InferenceSession
9+
import os
10+
import sys
11+
import subprocess
12+
import timeit
13+
import numpy as np
14+
import tensorflow as tf
15+
from tensorflow import keras
16+
from tensorflow.keras import Input
17+
try:
18+
import tensorflow_hub as tfhub
19+
except ImportError:
20+
# no tensorflow_hub
21+
print("tensorflow_hub not installed.")
22+
sys.exit(0)
23+
24+
########################################
25+
# Downloads the model.
26+
hub_layer = tfhub.KerasLayer(
27+
"https://tfhub.dev/google/efficientnet/b0/classification/1")
28+
model = keras.Sequential()
29+
model.add(Input(shape=(224, 224, 3), dtype=tf.float32))
30+
model.add(hub_layer)
31+
print(model.summary())
32+
33+
########################################
34+
# Saves the model.
35+
if not os.path.exists("efficientnetb0clas"):
36+
os.mkdir("efficientnetb0clas")
37+
tf.keras.models.save_model(model, "efficientnetb0clas")
38+
39+
input_names = [n.name for n in model.inputs]
40+
output_names = [n.name for n in model.outputs]
41+
print('inputs:', input_names)
42+
print('outputs:', output_names)
43+
44+
########################################
45+
# Testing the model.
46+
input = np.random.randn(2, 224, 224, 3).astype(np.float32)
47+
expected = model.predict(input)
48+
print(expected)
49+
50+
########################################
51+
# Run the command line.
52+
proc = subprocess.run(
53+
'python -m tf2onnx.convert --saved-model efficientnetb0clas '
54+
'--output efficientnetb0clas.onnx --opset 12'.split(),
55+
capture_output=True)
56+
print(proc.returncode)
57+
print(proc.stdout.decode('ascii'))
58+
print(proc.stderr.decode('ascii'))
59+
60+
########################################
61+
# Runs onnxruntime.
62+
session = InferenceSession("efficientnetb0clas.onnx")
63+
got = session.run(None, {'input_1:0': input})
64+
print(got[0])
65+
66+
########################################
67+
# Measures the differences.
68+
print(np.abs(got[0] - expected).max())
69+
70+
########################################
71+
# Measures processing time.
72+
print('tf:', timeit.timeit('model.predict(input)',
73+
number=10, globals=globals()))
74+
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
75+
number=10, globals=globals()))

examples/end2end_tfkeras.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
This example builds a simple model without training.
3+
It is converted into ONNX. Predictions are compared to
4+
the predictions from tensorflow to check there is no
5+
discrepencies. Inferencing time is also compared between
6+
*onnxruntime*, *tensorflow* and *tensorflow.lite*.
7+
"""
8+
from onnxruntime import InferenceSession
9+
import os
10+
import subprocess
11+
import timeit
12+
import numpy as np
13+
import tensorflow as tf
14+
from tensorflow import keras
15+
from tensorflow.keras import layers, Input
16+
17+
########################################
18+
# Creates the model.
19+
model = keras.Sequential()
20+
model.add(Input((4, 4)))
21+
model.add(layers.SimpleRNN(8))
22+
model.add(layers.Dense(2))
23+
print(model.summary())
24+
input_names = [n.name for n in model.inputs]
25+
output_names = [n.name for n in model.outputs]
26+
print('inputs:', input_names)
27+
print('outputs:', output_names)
28+
29+
########################################
30+
# Training
31+
# ....
32+
# Skipped.
33+
34+
########################################
35+
# Testing the model.
36+
input = np.random.randn(2, 4, 4).astype(np.float32)
37+
expected = model.predict(input)
38+
print(expected)
39+
40+
########################################
41+
# Saves the model.
42+
if not os.path.exists("simple_rnn"):
43+
os.mkdir("simple_rnn")
44+
tf.keras.models.save_model(model, "simple_rnn")
45+
46+
########################################
47+
# Run the command line.
48+
proc = subprocess.run('python -m tf2onnx.convert --saved-model simple_rnn '
49+
'--output simple_rnn.onnx --opset 12'.split(),
50+
capture_output=True)
51+
print(proc.returncode)
52+
print(proc.stdout.decode('ascii'))
53+
print(proc.stderr.decode('ascii'))
54+
55+
########################################
56+
# Runs onnxruntime.
57+
session = InferenceSession("simple_rnn.onnx")
58+
got = session.run(None, {'input_1:0': input})
59+
print(got[0])
60+
61+
########################################
62+
# Measures the differences.
63+
print(np.abs(got[0] - expected).max())
64+
65+
########################################
66+
# Measures processing time.
67+
print('tf:', timeit.timeit('model.predict(input)',
68+
number=100, globals=globals()))
69+
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
70+
number=100, globals=globals()))

tests/test_example.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""Test examples."""
5+
6+
import os
7+
import subprocess
8+
import unittest
9+
from common import check_opset_min_version, check_opset_max_version, check_tf_min_version
10+
11+
12+
class TestExample(unittest.TestCase):
13+
"""test examples"""
14+
15+
def run_example(self, name, expected=None):
16+
"Executes one example."
17+
full = os.path.join(
18+
os.path.abspath(os.path.dirname(__file__)),
19+
"..", "examples", name)
20+
if not os.path.exists(full):
21+
raise FileNotFoundError(full)
22+
proc = subprocess.run(('python %s' % full).split(),
23+
capture_output=True, check=True)
24+
self.assertEqual(0, proc.returncode)
25+
out = proc.stdout.decode('ascii')
26+
if 'tensorflow_hub not installed' in out:
27+
return
28+
err = proc.stderr.decode('ascii')
29+
self.assertTrue(err is not None)
30+
if expected is not None:
31+
for exp in expected:
32+
self.assertIn(exp, out)
33+
34+
@check_tf_min_version("2.3", "use tf.keras")
35+
@check_opset_min_version(12)
36+
@check_opset_max_version(13)
37+
def test_end2end_tfkeras(self):
38+
self.run_example(
39+
"end2end_tfkeras.py",
40+
expected=["ONNX model is saved at simple_rnn.onnx",
41+
"Optimizing ONNX model",
42+
"Using opset <onnx, 12>"])
43+
44+
@check_tf_min_version("2.3", "use tf.keras")
45+
@check_opset_min_version(12)
46+
@check_opset_max_version(13)
47+
def test_end2end_tfhub(self):
48+
self.run_example(
49+
"end2end_tfhub.py",
50+
expected=["ONNX model is saved at efficientnetb0clas.onnx",
51+
"Optimizing ONNX model",
52+
"Using opset <onnx, 12>"])
53+
54+
55+
if __name__ == '__main__':
56+
unittest.main()

0 commit comments

Comments
 (0)