|
5 | 5 | discrepancies. Inference time is also compared between
|
6 | 6 | *onnxruntime*, *tensorflow* and *tensorflow.lite*.
|
7 | 7 | """
|
| 8 | +from onnxruntime import InferenceSession
8 | 9 | import os
|
9 | 10 | import subprocess
|
10 | 11 | import timeit
|
|
37 | 38 | ########################################
|
38 | 39 | # Testing the model.
|
39 | 40 | input = np.random.randn(2, 4, 4).astype(np.float32)
|
40 | | -expected = model.predict(input)
| 41 | +expected = model.predict(input)
41 | 42 | print(expected)
|
42 | 43 |
|
43 | 44 | ########################################
|
|
57 | 58 |
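The ONNX conversion itself falls in the elided lines above. Given the os/subprocess imports, it presumably shells out to tf2onnx; a hedged sketch, assuming a SavedModel directory named simple_rnn (the actual command and paths are not shown in this hunk):

# Assumed conversion step; directory and file names are guesses from context.
subprocess.run([
    "python", "-m", "tf2onnx.convert",
    "--saved-model", "simple_rnn",
    "--output", "simple_rnn.onnx",
], check=True)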
|
58 | 59 | ########################################
|
59 | 60 | # Runs onnxruntime.
|
60 | | -from onnxruntime import InferenceSession
61 | 61 | session = InferenceSession("simple_rnn.onnx")
|
62 | 62 | got = session.run(None, {'input_1:0': input})
|
63 | 63 | print(got[0])
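Neither the discrepancy check promised in the docstring nor the input-name lookup appears in this hunk; a minimal sketch of both (the tolerances are assumptions, not values from the original script):

# Confirms the 'input_1:0' feed name straight from the session.
print(session.get_inputs()[0].name)
# Hypothetical discrepancy check against the Keras reference; rtol/atol are assumed.
np.testing.assert_allclose(expected, got[0], rtol=1e-5, atol=1e-5)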
|
|
79 | 79 | tflite_model = converter.convert()
|
80 | 80 | with open("simple_rnn.tflite", "wb") as f:
|
81 | 81 | f.write(tflite_model)
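The converter used above is built in lines that fall outside this hunk; for a Keras model it would typically be created along these lines (an assumption, not the verbatim original):

# Assumed construction of the TFLite converter from the in-memory Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)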
|
82 | | -
| 82 | +
83 | 83 | # Builds an interpreter.
|
84 | 84 | interpreter = tf.lite.Interpreter(model_path='simple_rnn.tflite')
|
85 | 85 | interpreter.allocate_tensors()
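The retrieval of the tensor metadata used below is likewise elided; it would typically read (a reconstruction, not the original lines):

# Assumed lookup of input/output tensor metadata.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()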
|
|
89 | 89 | print("output_details", output_details)
|
90 | 90 | index = input_details[0]['index']
|
91 | 91 |
|
| 92 | + |
92 | 93 | def tflite_predict(input, interpreter=interpreter, index=index):

93 | 94 |     res = []

94 | 95 |     for i in range(input.shape[0]):

95 | | -        interpreter.set_tensor(index, input[i:i+1])
| 96 | +        interpreter.set_tensor(index, input[i:i + 1])

96 | 97 |         interpreter.invoke()

97 | 98 |         res.append(interpreter.get_tensor(output_details[0]['index']))

98 | 99 |     return np.vstack(res)
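tflite_predict feeds samples one at a time because the interpreter allocates its input tensor with a fixed batch dimension. If the model tolerated resizing, the loop could be avoided; a sketch under that assumption:

# Alternative to the per-sample loop, assuming a resizable batch dimension.
interpreter.resize_tensor_input(index, input.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(index, input)
interpreter.invoke()
batched_output = interpreter.get_tensor(output_details[0]['index'])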
|
99 | 100 |
|
| 101 | + |
100 | 102 | print(input[0:1].shape, "----", input_details[0]['shape'])
|
101 | 103 | output_data = tflite_predict(input, interpreter, index)
|
102 | 104 | print(output_data)
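The timing comparison announced in the docstring (hence the timeit import) is not part of this hunk; a minimal sketch of how the three backends could be measured (the repeat count is an assumption):

# Hypothetical timing harness; number=100 is an assumed repeat count.
n = 100
print("tensorflow:", timeit.timeit(lambda: model.predict(input), number=n))
print("onnxruntime:", timeit.timeit(lambda: session.run(None, {'input_1:0': input}), number=n))
print("tensorflow.lite:", timeit.timeit(lambda: tflite_predict(input), number=n))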
|
|