Skip to content

Commit 10c283c

Browse files
committed
style, fix an issue with tensorflow 2.1
Signed-off-by: xavier dupré <[email protected]>
1 parent 6f667ef commit 10c283c

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

examples/end2end_tfkeras.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
discrepencies. Inferencing time is also compared between
66
*onnxruntime*, *tensorflow* and *tensorflow.lite*.
77
"""
8+
from onnxruntime import InferenceSession
89
import os
910
import subprocess
1011
import timeit
@@ -37,7 +38,7 @@
3738
########################################
3839
# Testing the model.
3940
input = np.random.randn(2, 4, 4).astype(np.float32)
40-
expected = model.predict(input)
41+
expected = model.predict(input)
4142
print(expected)
4243

4344
########################################
@@ -57,7 +58,6 @@
5758

5859
########################################
5960
# Runs onnxruntime.
60-
from onnxruntime import InferenceSession
6161
session = InferenceSession("simple_rnn.onnx")
6262
got = session.run(None, {'input_1:0': input})
6363
print(got[0])
@@ -79,7 +79,7 @@
7979
tflite_model = converter.convert()
8080
with open("simple_rnn.tflite", "wb") as f:
8181
f.write(tflite_model)
82-
82+
8383
# Builds an interpreter
8484
interpreter = tf.lite.Interpreter(model_path='simple_rnn.tflite')
8585
interpreter.allocate_tensors()
@@ -89,14 +89,16 @@
8989
print("output_details", output_details)
9090
index = input_details[0]['index']
9191

92+
9293
def tflite_predict(input, interpreter=interpreter, index=index):
9394
res = []
9495
for i in range(input.shape[0]):
95-
interpreter.set_tensor(index, input[i:i+1])
96+
interpreter.set_tensor(index, input[i:i + 1])
9697
interpreter.invoke()
9798
res.append(interpreter.get_tensor(output_details[0]['index']))
9899
return np.vstack(res)
99100

101+
100102
print(input[0:1].shape, "----", input_details[0]['shape'])
101103
output_data = tflite_predict(input, interpreter, index)
102104
print(output_data)

tests/test_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def run_example(self, name, expected=None):
2929
err = proc.stderr.decode('ascii')
3030
self.assertTrue(err is not None)
3131

32-
@check_tf_min_version("2.0", "use tf.keras")
32+
@check_tf_min_version("2.3", "use tf.keras")
3333
@check_opset_min_version(12)
3434
@check_opset_max_version(13)
3535
def test_end2end_tfkeras(self):

0 commit comments

Comments
 (0)