|
68 | 68 | number=100, globals=globals()))
|
69 | 69 | print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
|
70 | 70 | number=100, globals=globals()))
|
71 |
| - |
72 |
| -######################################## |
73 |
| -# Freezes the graph with tensorflow.lite. |
74 |
| -converter = tf.lite.TFLiteConverter.from_saved_model("simple_rnn") |
75 |
| -tflite_model = converter.convert() |
76 |
| -with open("simple_rnn.tflite", "wb") as f: |
77 |
| - f.write(tflite_model) |
78 |
| - |
79 |
| -# Builds an interpreter. |
80 |
| -interpreter = tf.lite.Interpreter(model_path='simple_rnn.tflite') |
81 |
| -interpreter.allocate_tensors() |
82 |
| -input_details = interpreter.get_input_details() |
83 |
| -output_details = interpreter.get_output_details() |
84 |
| -print("input_details", input_details) |
85 |
| -print("output_details", output_details) |
86 |
| -index = input_details[0]['index'] |
87 |
| - |
88 |
| - |
89 |
| -def tflite_predict(input, interpreter=interpreter, index=index): |
90 |
| - res = [] |
91 |
| - for i in range(input.shape[0]): |
92 |
| - interpreter.set_tensor(index, input[i:i + 1]) |
93 |
| - interpreter.invoke() |
94 |
| - res.append(interpreter.get_tensor(output_details[0]['index'])) |
95 |
| - return np.vstack(res) |
96 |
| - |
97 |
| - |
98 |
| -print(input[0:1].shape, "----", input_details[0]['shape']) |
99 |
| -output_data = tflite_predict(input, interpreter, index) |
100 |
| -print(output_data) |
101 |
| - |
102 |
| -######################################## |
103 |
| -# Measures processing time again. |
104 |
| - |
105 |
| -print('tf:', timeit.timeit('model.predict(input)', |
106 |
| - number=100, globals=globals())) |
107 |
| -print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})", |
108 |
| - number=100, globals=globals())) |
109 |
| -print('tflite:', timeit.timeit('tflite_predict(input)', |
110 |
| - number=100, globals=globals())) |
111 |
| - |
112 |
| -######################################## |
113 |
| -# Measures processing time only between onnxruntime and |
114 |
| -# tensorflow lite with more loops. |
115 |
| - |
116 |
| -print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})", |
117 |
| - number=10000, globals=globals())) |
118 |
| -print('tflite:', timeit.timeit('tflite_predict(input)', |
119 |
| - number=10000, globals=globals())) |
0 commit comments