Commit 6f667ef

Add tensorflow lite to the example.
Signed-off-by: xavier dupré <[email protected]>
1 parent 8eb9c28 commit 6f667ef

File tree

1 file changed (+82 -8 lines)

examples/end2end_tfkeras.py

Lines changed: 82 additions & 8 deletions
@@ -1,34 +1,52 @@
+"""
+This example builds a simple model without training.
+It is converted into ONNX. Predictions are compared to
+the predictions from tensorflow to check there are no
+discrepancies. Inference time is also compared between
+*onnxruntime*, *tensorflow* and *tensorflow.lite*.
+"""
 import os
 import subprocess
+import timeit
 import numpy as np
 import tensorflow as tf
 from tensorflow import keras
 from tensorflow.keras import layers, Input
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.tools import freeze_graph
+from tf2onnx.tf_loader import tf_reset_default_graph, tf_session, freeze_session

-
+########################################
 # Creates the model.
 model = keras.Sequential()
 #model.add(layers.Embedding(input_dim=10, output_dim=4))
 model.add(Input((4, 4)))
 model.add(layers.SimpleRNN(8))
 model.add(layers.Dense(2))
 print(model.summary())
-print(model.inputs)
-print(model.outputs)
+input_names = [n.name for n in model.inputs]
+output_names = [n.name for n in model.outputs]
+print('inputs:', input_names)
+print('outputs:', output_names)
+
+########################################
+# Training
+# ....
+# Skipped.

+########################################
 # Testing the model.
 input = np.random.randn(2, 4, 4).astype(np.float32)
 expected = model.predict(input)
 print(expected)

-# Training
-# ....
-
+########################################
 # Saves the model.
 if not os.path.exists("simple_rnn"):
     os.mkdir("simple_rnn")
 tf.keras.models.save_model(model, "simple_rnn")

+########################################
 # Run the command line.
 proc = subprocess.run('python -m tf2onnx.convert --saved-model simple_rnn '
                       '--output simple_rnn.onnx --opset 12'.split(),
@@ -37,11 +55,67 @@
 print(proc.stdout.decode('ascii'))
 print(proc.stderr.decode('ascii'))

-# Run onnxruntime.
+########################################
+# Runs onnxruntime.
 from onnxruntime import InferenceSession
 session = InferenceSession("simple_rnn.onnx")
 got = session.run(None, {'input_1:0': input})
 print(got[0])

-# Differences
+########################################
+# Measures the differences.
 print(np.abs(got[0] - expected).max())
+
+########################################
+# Measures processing time.
+print('tf:', timeit.timeit('model.predict(input)',
+             number=100, globals=globals()))
+print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
+              number=100, globals=globals()))
+
+########################################
+# Freezes the graph with tensorflow.lite
+converter = tf.lite.TFLiteConverter.from_saved_model("simple_rnn")
+tflite_model = converter.convert()
+with open("simple_rnn.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Builds an interpreter
+interpreter = tf.lite.Interpreter(model_path='simple_rnn.tflite')
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+print("input_details", input_details)
+print("output_details", output_details)
+index = input_details[0]['index']
+
+def tflite_predict(input, interpreter=interpreter, index=index):
+    res = []
+    for i in range(input.shape[0]):
+        interpreter.set_tensor(index, input[i:i+1])
+        interpreter.invoke()
+        res.append(interpreter.get_tensor(output_details[0]['index']))
+    return np.vstack(res)
+
+print(input[0:1].shape, "----", input_details[0]['shape'])
+output_data = tflite_predict(input, interpreter, index)
+print(output_data)
+
+########################################
+# Measures processing time again.
+
+print('tf:', timeit.timeit('model.predict(input)',
+             number=100, globals=globals()))
+print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
+              number=100, globals=globals()))
+print('tflite:', timeit.timeit('tflite_predict(input)',
+                 number=100, globals=globals()))
+
+########################################
+# Measures processing time only between onnxruntime and
+# tensorflow lite with more loops.
+
+print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
+              number=10000, globals=globals()))
+print('tflite:', timeit.timeit('tflite_predict(input)',
+                 number=10000, globals=globals()))
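
Side note, not part of this commit: depending on the installed version, tf2onnx (1.9 and later) also exposes a Python conversion API, which avoids spawning the command line through subprocess as the example does. A minimal sketch, assuming tf2onnx.convert.from_keras is available in the installed release; the signature name "input" and the file name simple_rnn_api.onnx are illustrative choices, not part of the example above.

# A minimal sketch, assuming tf2onnx >= 1.9 exposes tf2onnx.convert.from_keras.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Input
import tf2onnx
from onnxruntime import InferenceSession

model = keras.Sequential([Input((4, 4)), layers.SimpleRNN(8), layers.Dense(2)])
# The signature name below becomes the ONNX graph input name.
spec = (tf.TensorSpec((None, 4, 4), tf.float32, name="input"),)
tf2onnx.convert.from_keras(model, input_signature=spec, opset=12,
                           output_path="simple_rnn_api.onnx")

x = np.random.randn(2, 4, 4).astype(np.float32)
sess = InferenceSession("simple_rnn_api.onnx")
# Largest absolute difference between onnxruntime and keras predictions.
print(np.abs(sess.run(None, {"input": x})[0] - model.predict(x)).max())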

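Also outside the commit: tflite_predict feeds one sample at a time because the interpreter built from the converted model allocates its input with a batch size of 1 (that is what input_details[0]['shape'] printed above is expected to show). A minimal sketch of an alternative, assuming the converted model's input supports resizing through tf.lite.Interpreter.resize_tensor_input:

# A minimal sketch, assuming simple_rnn.tflite exists and its input can be resized.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="simple_rnn.tflite")
in_index = interpreter.get_input_details()[0]['index']
out_index = interpreter.get_output_details()[0]['index']
# Resize the input to the full batch before allocating tensors.
interpreter.resize_tensor_input(in_index, [2, 4, 4])
interpreter.allocate_tensors()

x = np.random.randn(2, 4, 4).astype(np.float32)
interpreter.set_tensor(in_index, x)
interpreter.invoke()
print(interpreter.get_tensor(out_index))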