Skip to content

Commit 1aec83c

Browse files
committed
intermediate results in tflite
Signed-off-by: xavier dupré <[email protected]>
1 parent d6c3ebf commit 1aec83c

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tests/tfhub/_tools.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,15 @@ def call_tflite(inp):
403403
with open(onnx_name, "rb") as f:
404404
model_onnx = onnx.load(f)
405405

406-
call_tflite(imgs[0])
406+
interpreter_details = tf.lite.Interpreter(tname, experimental_preserve_all_tensors=True)
407+
input_details = interpreter_details.get_input_details()
408+
index_in = input_details[0]['index']
409+
interpreter_details.allocate_tensors()
410+
interpreter_details.set_tensor(index_in, imgs[0])
411+
interpreter_details.invoke()
412+
details = interpreter_details.get_tensor_details()
413+
407414
inputs = {input_name: imgs[0]}
408-
details = interpreter.get_tensor_details()
409415
names_index = {}
410416
for tt in details:
411417
names_index[tt['name']] = (tt['index'], tt['quantization'], tt['quantization_parameters'])
@@ -414,7 +420,7 @@ def call_tflite(inp):
414420
for name_tfl, name_ort in names:
415421
index = names_index[name_tfl]
416422

417-
tfl_value = interpreter.get_tensor(index[0])
423+
tfl_value = interpreter_details.get_tensor(index[0])
418424

419425
new_name = onnx_name + ".%s.onnx" % name_ort.replace(":", "_").replace(";", "_").replace("/", "_")
420426
if not os.path.exists(new_name):

0 commit comments

Comments
 (0)