Skip to content

Commit 4becde0

Browse files
Add tflite support to run_pretrained_models.py (#1313)
* Add tflite support to run_pretrained_models.py Signed-off-by: Tom Wildenhain <[email protected]> * added tflite pretrained model Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 43d21a7 commit 4becde0

File tree

5 files changed

+137
-45
lines changed

5 files changed

+137
-45
lines changed

ci_build/azure_pipelines/pretrained_model_test.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
# Pre-trained model test
22

33
jobs:
4+
- template: 'templates/job_generator.yml'
5+
parameters:
6+
python_versions: ['3.7']
7+
tf_versions: ['2.4.1']
8+
skip_tflite_tests: 'False'
9+
skip_tf_tests: 'True'
10+
job:
11+
steps:
12+
- template: 'pretrained_model_test.yml'
13+
414
- template: 'templates/job_generator.yml'
515
parameters:
616
python_versions: ['3.7']

ci_build/azure_pipelines/templates/pretrained_model_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ steps:
66
status=0
77
# TODO: fix unity model path
88
# python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --config tests/unity.yaml || status=$?
9-
python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --config tests/run_pretrained_models.yaml || status=$?
9+
python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --skip_tf_tests $CI_SKIP_TF_TESTS --skip_tflite_tests $CI_SKIP_TFLITE_TESTS --config tests/run_pretrained_models.yaml || status=$?
1010
exit $status
1111
displayName: 'Test Pre-trained Model'

tests/car.JPEG

29.9 KB
Loading

tests/run_pretrained_models.py

Lines changed: 107 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,27 @@
5656
PERFITER = 1000
5757

5858

59-
def get_beach(shape):
60-
"""Get beach image as input."""
59+
def get_img(shape, path, dtype, should_scale=True):
60+
"""Get image as input."""
6161
resize_to = shape[1:3]
62-
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "beach.jpg")
62+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
6363
img = PIL.Image.open(path)
6464
img = img.resize(resize_to, PIL.Image.ANTIALIAS)
65-
img_np = np.array(img).astype(np.float32)
65+
img_np = np.array(img).astype(dtype)
6666
img_np = np.stack([img_np] * shape[0], axis=0).reshape(shape)
67-
return img_np / 255
67+
if should_scale:
68+
img_np = img_np / 255
69+
return img_np
70+
71+
72+
def get_beach(shape):
73+
"""Get beach image as input."""
74+
return get_img(shape, "beach.jpg", np.float32, should_scale=True)
75+
76+
77+
def get_car(shape):
78+
"""Get car image as input."""
79+
return get_img(shape, "car.JPEG", np.float32, should_scale=True)
6880

6981

7082
def get_random(shape):
@@ -133,6 +145,7 @@ def get_sentence():
133145

134146
_INPUT_FUNC_MAPPING = {
135147
"get_beach": get_beach,
148+
"get_car": get_car,
136149
"get_random": get_random,
137150
"get_random256": get_random256,
138151
"get_ramp": get_ramp,
@@ -219,6 +232,9 @@ def download_model(self):
219232
elif url.endswith('.zip'):
220233
ftype = 'zip'
221234
dir_name = fname.replace(".zip", "")
235+
elif url.endswith('.tflite'):
236+
ftype = 'tflite'
237+
dir_name = fname.replace(".tflite", "")
222238
dir_name = os.path.join(cache_dir, dir_name)
223239
os.makedirs(dir_name, exist_ok=True)
224240
fpath = os.path.join(dir_name, fname)
@@ -266,7 +282,7 @@ def run_tensorflow(self, sess, inputs):
266282
return result
267283

268284
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
269-
const_node_values=None, initialized_tables=None):
285+
const_node_values=None, initialized_tables=None, tflite_path=None):
270286
"""Convert graph to tensorflow."""
271287
if extra_opset is None:
272288
extra_opset = []
@@ -275,8 +291,8 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
275291
return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
276292
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
277293
input_names=input_names, output_names=self.output_names,
278-
const_node_values=const_node_values,
279-
initialized_tables=initialized_tables)
294+
const_node_values=const_node_values, initialized_tables=initialized_tables,
295+
tflite_path=tflite_path)
280296

281297
def run_caffe2(self, name, model_proto, inputs):
282298
"""Run test again caffe2 backend."""
@@ -340,6 +356,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
340356
input_names = list(self.input_names.keys())
341357
initialized_tables = {}
342358
outputs = self.output_names
359+
tflite_path = None
343360
if self.model_type in ["checkpoint"]:
344361
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
345362
elif self.model_type in ["saved_model"]:
@@ -355,12 +372,43 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
355372
graph_def, input_names, outputs, initialized_tables = loaded
356373
elif self.model_type in ["keras"]:
357374
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
375+
elif self.model_type in ["tflite"]:
376+
tflite_path = model_path
377+
graph_def = None
358378
else:
359379
graph_def, input_names, outputs = tf_loader.from_graphdef(model_path, input_names, outputs)
360380

361381
if utils.is_debug_mode():
362382
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
363383

384+
if tflite_path is not None:
385+
inputs = {}
386+
for k in input_names:
387+
v = self.input_names[k]
388+
inputs[k] = self.make_input(v)
389+
390+
interpreter = tf.lite.Interpreter(tflite_path)
391+
input_details = interpreter.get_input_details()
392+
output_details = interpreter.get_output_details()
393+
input_name_to_index = {n['name'].split(':')[0]: n['index'] for n in input_details}
394+
for k, v in inputs.items():
395+
interpreter.resize_tensor_input(input_name_to_index[k], v.shape)
396+
interpreter.allocate_tensors()
397+
def run_tflite():
398+
for k, v in inputs.items():
399+
interpreter.set_tensor(input_name_to_index[k], v)
400+
interpreter.invoke()
401+
result = [interpreter.get_tensor(output['index']) for output in output_details]
402+
return result
403+
tf_results = run_tflite()
404+
if self.perf:
405+
logger.info("Running TFLite perf")
406+
start = time.time()
407+
for _ in range(PERFITER):
408+
_ = run_tflite()
409+
self.tf_runtime = time.time() - start
410+
logger.info("TFLite OK")
411+
364412
if not self.run_tf_frozen:
365413
inputs = {}
366414
for k in input_names:
@@ -384,45 +432,50 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
384432
self.tf_runtime = time.time() - start
385433
logger.info("TensorFlow OK")
386434

387-
inputs = {}
388435
shape_override = {}
389-
tf_reset_default_graph()
390-
391-
from tf2onnx.tf_utils import compress_graph_def
392436
const_node_values = None
393-
with tf.Graph().as_default() as tf_graph:
394-
if self.large_model:
395-
const_node_values = compress_graph_def(graph_def)
396-
tf.import_graph_def(graph_def, name='')
437+
tf_graph = None
397438

398-
with tf_session(graph=tf_graph) as sess:
399-
# create the input data
400-
for k in input_names:
401-
v = self.input_names[k]
402-
t = sess.graph.get_tensor_by_name(k)
403-
expected_dtype = tf.as_dtype(t.dtype).name
404-
if isinstance(v, six.text_type) and v.startswith("np."):
405-
np_value = eval(v) # pylint: disable=eval-used
406-
if expected_dtype != np_value.dtype:
407-
logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
408-
np_value.dtype)
409-
inputs[k] = np_value.astype(expected_dtype)
410-
else:
411-
if expected_dtype == "string":
412-
inputs[k] = self.make_input(v).astype(np.str).astype(np.object)
439+
if graph_def is not None:
440+
inputs = {}
441+
tf_reset_default_graph()
442+
443+
with tf.Graph().as_default() as tf_graph:
444+
from tf2onnx.tf_utils import compress_graph_def
445+
if self.large_model:
446+
const_node_values = compress_graph_def(graph_def)
447+
tf.import_graph_def(graph_def, name='')
448+
449+
with tf_session(graph=tf_graph) as sess:
450+
# create the input data
451+
for k in input_names:
452+
v = self.input_names[k]
453+
t = sess.graph.get_tensor_by_name(k)
454+
expected_dtype = tf.as_dtype(t.dtype).name
455+
if isinstance(v, six.text_type) and v.startswith("np."):
456+
np_value = eval(v) # pylint: disable=eval-used
457+
if expected_dtype != np_value.dtype:
458+
logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
459+
np_value.dtype)
460+
inputs[k] = np_value.astype(expected_dtype)
413461
else:
414-
inputs[k] = self.make_input(v).astype(expected_dtype)
462+
if expected_dtype == "string":
463+
inputs[k] = self.make_input(v).astype(np.str).astype(np.object)
464+
else:
465+
inputs[k] = self.make_input(v).astype(expected_dtype)
415466

416-
if self.force_input_shape:
417-
for k, v in inputs.items():
418-
shape_override[k] = list(v.shape)
467+
if self.force_input_shape:
468+
for k, v in inputs.items():
469+
shape_override[k] = list(v.shape)
470+
471+
# run the model with tensorflow
472+
if self.skip_tensorflow:
473+
logger.info("TensorFlow SKIPPED")
474+
elif self.run_tf_frozen:
475+
tf_results = self.run_tensorflow(sess, inputs)
476+
logger.info("TensorFlow OK")
477+
tf_graph = sess.graph
419478

420-
# run the model with tensorflow
421-
if self.skip_tensorflow:
422-
logger.info("TensorFlow SKIPPED")
423-
elif self.run_tf_frozen:
424-
tf_results = self.run_tensorflow(sess, inputs)
425-
logger.info("TensorFlow OK")
426479

427480
model_proto = None
428481
if self.skip_conversion:
@@ -436,10 +489,10 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
436489
else:
437490
try:
438491
# convert model to onnx
439-
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
492+
onnx_graph = self.to_onnx(tf_graph, opset=opset, extra_opset=extra_opset,
440493
shape_override=shape_override, input_names=inputs.keys(),
441494
const_node_values=const_node_values,
442-
initialized_tables=initialized_tables)
495+
initialized_tables=initialized_tables, tflite_path=tflite_path)
443496
onnx_graph = optimizer.optimize_graph(onnx_graph)
444497
print("ONNX", onnx_graph.dump_node_statistics())
445498
external_tensor_storage = ExternalTensorStorage() if self.large_model else None
@@ -538,6 +591,8 @@ def get_args():
538591
parser.add_argument("--opset", type=int, default=None, help="opset to use")
539592
parser.add_argument("--extra_opset", default=None,
540593
help="extra opset with format like domain:version, e.g. com.microsoft:1")
594+
parser.add_argument("--skip_tf_tests", help="skip non-tflite tests", default="False")
595+
parser.add_argument("--skip_tflite_tests", help="skip tflite tests", default="False")
541596
parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
542597
parser.add_argument("--debug", help="debug mode", action="store_true")
543598
parser.add_argument("--list", help="list tests", action="store_true")
@@ -550,6 +605,8 @@ def get_args():
550605
args = parser.parse_args()
551606

552607
args.target = args.target.split(",")
608+
args.skip_tf_tests = args.skip_tf_tests.upper() == "TRUE"
609+
args.skip_tflite_tests = args.skip_tflite_tests.upper() == "TRUE"
553610
if args.extra_opset:
554611
tokens = args.extra_opset.split(':')
555612
if len(tokens) != 2:
@@ -644,6 +701,13 @@ def main():
644701
logger.info("Skip %s: disabled", test)
645702
continue
646703

704+
if args.skip_tflite_tests and t.model_type == "tflite":
705+
logger.info("Skip %s: tflite test", test)
706+
continue
707+
if args.skip_tf_tests and t.model_type != "tflite":
708+
logger.info("Skip %s: not tflite test", test)
709+
continue
710+
647711
condition, reason = t.check_opset_constraints(args.opset, args.extra_opset)
648712
if not condition:
649713
logger.info("Skip %s: %s", test, reason)

tests/run_pretrained_models.yaml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,4 +435,22 @@ keras_mobilenet_v2:
435435
inputs:
436436
"input_1:0": [1, 224, 224, 3]
437437
outputs:
438-
- Identity:0
438+
- Identity:0
439+
440+
ssd_mobilenet_v2_300_float_tflite:
441+
tf_min_version: 2.1
442+
disabled: false
443+
url: https://github.com/mlcommons/mobile_models/raw/main/v0_7/tflite/ssd_mobilenet_v2_300_float.tflite
444+
model: "ssd_mobilenet_v2_300_float.tflite"
445+
model_type: tflite
446+
input_get: get_car
447+
opset_constraints:
448+
"onnx":
449+
"min": 11
450+
inputs:
451+
"normalized_input_image_tensor": [1, 300, 300, 3]
452+
outputs:
453+
- TFLite_Detection_PostProcess
454+
- TFLite_Detection_PostProcess:1
455+
- TFLite_Detection_PostProcess:2
456+
- TFLite_Detection_PostProcess:3

0 commit comments

Comments
 (0)