Skip to content

Commit d9abc4c

Browse files
committed
add one more example
Signed-off-by: xavier dupré <[email protected]>
1 parent 10c283c commit d9abc4c

File tree

3 files changed

+122
-4
lines changed

3 files changed

+122
-4
lines changed

examples/end2end_tfhub.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
This example retrieves a model from tensorflowhub.
3+
It is converted into ONNX. Predictions are compared to
4+
the predictions from tensorflow to check there is no
5+
discrepencies. Inferencing time is also compared between
6+
*onnxruntime*, *tensorflow* and *tensorflow.lite*.
7+
"""
8+
from onnxruntime import InferenceSession
9+
import os
10+
import subprocess
11+
import timeit
12+
import numpy as np
13+
import tensorflow as tf
14+
from tensorflow import keras
15+
from tensorflow.keras import layers, Input
16+
from tensorflow.python.saved_model import tag_constants
17+
import tensorflow_hub as tfhub
18+
19+
########################################
20+
# Downloads the model.
21+
hub_layer = tfhub.KerasLayer(
22+
"https://tfhub.dev/google/efficientnet/b0/classification/1")
23+
model = keras.Sequential()
24+
model.add(tf.keras.Input(shape=(224, 224, 3), dtype=tf.float32))
25+
model.add(hub_layer)
26+
print(model.summary())
27+
28+
########################################
29+
# Saves the model.
30+
if not os.path.exists("efficientnetb0clas"):
31+
os.mkdir("efficientnetb0clas")
32+
tf.keras.models.save_model(model, "efficientnetb0clas")
33+
34+
input_names = [n.name for n in model.inputs]
35+
output_names = [n.name for n in model.outputs]
36+
print('inputs:', input_names)
37+
print('outputs:', output_names)
38+
39+
########################################
40+
# Testing the model.
41+
input = np.random.randn(2, 224, 224, 3).astype(np.float32)
42+
expected = model.predict(input)
43+
print(expected)
44+
45+
########################################
46+
# Run the command line.
47+
proc = subprocess.run(
48+
'python -m tf2onnx.convert --saved-model efficientnetb0clas '
49+
'--output efficientnetb0clas.onnx --opset 12'.split(),
50+
capture_output=True)
51+
print(proc.returncode)
52+
print(proc.stdout.decode('ascii'))
53+
print(proc.stderr.decode('ascii'))
54+
55+
########################################
56+
# Runs onnxruntime.
57+
session = InferenceSession("efficientnetb0clas.onnx")
58+
got = session.run(None, {'input_1:0': input})
59+
print(got[0])
60+
61+
########################################
62+
# Measures the differences.
63+
print(np.abs(got[0] - expected).max())
64+
65+
########################################
66+
# Measures processing time.
67+
print('tf:', timeit.timeit('model.predict(input)',
68+
number=10, globals=globals()))
69+
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
70+
number=10, globals=globals()))
71+
72+
########################################
73+
# Freezes the graph with tensorflow.lite.
74+
converter = tf.lite.TFLiteConverter.from_saved_model("efficientnetb0clas")
75+
tflite_model = converter.convert()
76+
with open("efficientnetb0clas.tflite", "wb") as f:
77+
f.write(tflite_model)
78+
79+
# Builds an interpreter.
80+
interpreter = tf.lite.Interpreter(model_path='efficientnetb0clas.tflite')
81+
interpreter.allocate_tensors()
82+
input_details = interpreter.get_input_details()
83+
output_details = interpreter.get_output_details()
84+
print("input_details", input_details)
85+
print("output_details", output_details)
86+
index = input_details[0]['index']
87+
88+
89+
def tflite_predict(input, interpreter=interpreter, index=index):
90+
res = []
91+
for i in range(input.shape[0]):
92+
interpreter.set_tensor(index, input[i:i + 1])
93+
interpreter.invoke()
94+
res.append(interpreter.get_tensor(output_details[0]['index']))
95+
return np.vstack(res)
96+
97+
98+
print(input[0:1].shape, "----", input_details[0]['shape'])
99+
output_data = tflite_predict(input, interpreter, index)
100+
print(output_data)
101+
102+
########################################
103+
# Measures processing time again.
104+
105+
print('tf:', timeit.timeit('model.predict(input)',
106+
number=10, globals=globals()))
107+
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
108+
number=10, globals=globals()))
109+
print('tflite:', timeit.timeit('tflite_predict(input)',
110+
number=10, globals=globals()))

examples/end2end_tfkeras.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
from tensorflow.keras import layers, Input
1616
from tensorflow.python.saved_model import tag_constants
1717
from tensorflow.python.tools import freeze_graph
18-
from tf2onnx.tf_loader import tf_reset_default_graph, tf_session, freeze_session
1918

2019
########################################
2120
# Creates the model.
2221
model = keras.Sequential()
23-
#model.add(layers.Embedding(input_dim=10, output_dim=4))
2422
model.add(Input((4, 4)))
2523
model.add(layers.SimpleRNN(8))
2624
model.add(layers.Dense(2))
@@ -74,13 +72,13 @@
7472
number=100, globals=globals()))
7573

7674
########################################
77-
# Freezes the graph with tensorflow.lite
75+
# Freezes the graph with tensorflow.lite.
7876
converter = tf.lite.TFLiteConverter.from_saved_model("simple_rnn")
7977
tflite_model = converter.convert()
8078
with open("simple_rnn.tflite", "wb") as f:
8179
f.write(tflite_model)
8280

83-
# Builds an interpreter
81+
# Builds an interpreter.
8482
interpreter = tf.lite.Interpreter(model_path='simple_rnn.tflite')
8583
interpreter.allocate_tensors()
8684
input_details = interpreter.get_input_details()

tests/test_example.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ def test_end2end_tfkeras(self):
3939
"Optimizing ONNX model",
4040
"Using opset <onnx, 12>"])
4141

42+
@check_tf_min_version("2.3", "use tf.keras")
43+
@check_opset_min_version(12)
44+
@check_opset_max_version(13)
45+
def test_end2end_tfhub(self):
46+
self.run_example(
47+
"end2end_tfhub.py",
48+
expected=["ONNX model is saved at efficientnetb0clas.onnx",
49+
"Optimizing ONNX model",
50+
"Using opset <onnx, 12>"])
51+
4252

4353
if __name__ == '__main__':
4454
unittest.main()

0 commit comments

Comments
 (0)