Skip to content

Commit 1bcb143

Browse files
committed
remove tflite part
Signed-off-by: xavier dupré <[email protected]>
1 parent 20a9deb commit 1bcb143

File tree

2 files changed

+0
-89
lines changed

2 files changed

+0
-89
lines changed

examples/end2end_tfhub.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -73,43 +73,3 @@
7373
number=10, globals=globals()))
7474
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
7575
number=10, globals=globals()))
76-
77-
########################################
78-
# Freezes the graph with tensorflow.lite.
79-
converter = tf.lite.TFLiteConverter.from_saved_model("efficientnetb0clas")
80-
tflite_model = converter.convert()
81-
with open("efficientnetb0clas.tflite", "wb") as f:
82-
f.write(tflite_model)
83-
84-
# Builds an interpreter.
85-
interpreter = tf.lite.Interpreter(model_path='efficientnetb0clas.tflite')
86-
interpreter.allocate_tensors()
87-
input_details = interpreter.get_input_details()
88-
output_details = interpreter.get_output_details()
89-
print("input_details", input_details)
90-
print("output_details", output_details)
91-
index = input_details[0]['index']
92-
93-
94-
def tflite_predict(input, interpreter=interpreter, index=index):
95-
res = []
96-
for i in range(input.shape[0]):
97-
interpreter.set_tensor(index, input[i:i + 1])
98-
interpreter.invoke()
99-
res.append(interpreter.get_tensor(output_details[0]['index']))
100-
return np.vstack(res)
101-
102-
103-
print(input[0:1].shape, "----", input_details[0]['shape'])
104-
output_data = tflite_predict(input, interpreter, index)
105-
print(output_data)
106-
107-
########################################
108-
# Measures processing time again.
109-
110-
print('tf:', timeit.timeit('model.predict(input)',
111-
number=10, globals=globals()))
112-
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
113-
number=10, globals=globals()))
114-
print('tflite:', timeit.timeit('tflite_predict(input)',
115-
number=10, globals=globals()))

examples/end2end_tfkeras.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -68,52 +68,3 @@
6868
number=100, globals=globals()))
6969
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
7070
number=100, globals=globals()))
71-
72-
########################################
73-
# Freezes the graph with tensorflow.lite.
74-
converter = tf.lite.TFLiteConverter.from_saved_model("simple_rnn")
75-
tflite_model = converter.convert()
76-
with open("simple_rnn.tflite", "wb") as f:
77-
f.write(tflite_model)
78-
79-
# Builds an interpreter.
80-
interpreter = tf.lite.Interpreter(model_path='simple_rnn.tflite')
81-
interpreter.allocate_tensors()
82-
input_details = interpreter.get_input_details()
83-
output_details = interpreter.get_output_details()
84-
print("input_details", input_details)
85-
print("output_details", output_details)
86-
index = input_details[0]['index']
87-
88-
89-
def tflite_predict(input, interpreter=interpreter, index=index):
90-
res = []
91-
for i in range(input.shape[0]):
92-
interpreter.set_tensor(index, input[i:i + 1])
93-
interpreter.invoke()
94-
res.append(interpreter.get_tensor(output_details[0]['index']))
95-
return np.vstack(res)
96-
97-
98-
print(input[0:1].shape, "----", input_details[0]['shape'])
99-
output_data = tflite_predict(input, interpreter, index)
100-
print(output_data)
101-
102-
########################################
103-
# Measures processing time again.
104-
105-
print('tf:', timeit.timeit('model.predict(input)',
106-
number=100, globals=globals()))
107-
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
108-
number=100, globals=globals()))
109-
print('tflite:', timeit.timeit('tflite_predict(input)',
110-
number=100, globals=globals()))
111-
112-
########################################
113-
# Measures processing time only between onnxruntime and
114-
# tensorflow lite with more loops.
115-
116-
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
117-
number=10000, globals=globals()))
118-
print('tflite:', timeit.timeit('tflite_predict(input)',
119-
number=10000, globals=globals()))

0 commit comments

Comments
 (0)