-
Notifications
You must be signed in to change notification settings - Fork 6k
Expand file tree
/
Copy pathmodel_zoo.py
More file actions
98 lines (84 loc) · 3.49 KB
/
model_zoo.py
File metadata and controls
98 lines (84 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
import os
import os.path as osp
import glob
import onnxruntime
from .arcface_onnx import *
from .retinaface import *
#from .scrfd import *
from .landmark import *
from .attribute import Attribute
from .inswapper import INSwapper
from ..utils import download_onnx
__all__ = ['get_model']
class PickableInferenceSession(onnxruntime.InferenceSession):
    """``onnxruntime.InferenceSession`` wrapper that supports pickling.

    Only the data needed to rebuild the session is pickled: the model path and
    the ``providers`` / ``provider_options`` keyword arguments given to the
    constructor. Unpickling re-runs ``__init__`` with those values, so the
    model is loaded again from ``model_path``.

    NOTE: any other keyword argument (for example ``sess_options``) is not
    carried through a pickle round trip; the restored session is built without
    it.
    """

    # Constructor keyword arguments that are replayed when unpickling.
    _PICKLED_KWARGS = ('providers', 'provider_options')

    def __init__(self, model_path, **kwargs):
        """Create the session and remember how to recreate it.

        Args:
            model_path: Forwarded as the first argument of
                ``onnxruntime.InferenceSession``.
            **kwargs: Forwarded unchanged to ``onnxruntime.InferenceSession``.
        """
        super().__init__(model_path, **kwargs)
        self.model_path = model_path
        # Keep only the explicitly supplied (non-None) provider settings.
        self._pickle_kwargs = {
            key: kwargs[key]
            for key in self._PICKLED_KWARGS
            if kwargs.get(key) is not None
        }

    def __getstate__(self):
        """Return the picklable state: model path plus provider settings."""
        state = {'model_path': self.model_path}
        state.update(self._pickle_kwargs)
        return state

    def __setstate__(self, values):
        """Rebuild the session from a state produced by ``__getstate__``.

        States that only contain ``'model_path'`` (the older layout) are still
        accepted; the session is then created without provider arguments.
        """
        restored_kwargs = {
            key: values[key]
            for key in self._PICKLED_KWARGS
            if values.get(key) is not None
        }
        # Re-running __init__ reloads the model and restores the attributes.
        self.__init__(values['model_path'], **restored_kwargs)
class ModelRouter:
    """Choose the model wrapper class that fits a given ONNX file.

    The decision is made from the number of graph outputs, the number of graph
    inputs and axes 2 and 3 of the first input's shape.
    """

    def __init__(self, onnx_file):
        """Remember the ONNX file that ``get_model`` will load."""
        self.onnx_file = onnx_file

    def get_model(self, **kwargs):
        """Open the ONNX file and wrap the session in the matching model class.

        Rules, checked in this order:

        1. five or more outputs              -> ``RetinaFace``
        2. first input is 192 x 192          -> ``Landmark``
        3. first input is 96 x 96            -> ``Attribute``
        4. two inputs, first is 128 x 128    -> ``INSwapper``
        5. first input is square, at least 112 and a multiple of 16
                                             -> ``ArcFaceONNX``

        Args:
            **kwargs: Forwarded to ``PickableInferenceSession``.

        Returns:
            The wrapper instance, or ``None`` when no rule matches. ``None`` is
            also returned when the first input has fewer than four axes or its
            size on axes 2/3 is not a concrete integer.
        """
        session = PickableInferenceSession(self.onnx_file, **kwargs)
        print(f'Applied providers: {session.get_providers()}, with options: {session.get_provider_options()}')
        inputs = session.get_inputs()
        input_shape = inputs[0].shape
        outputs = session.get_outputs()

        if len(outputs) >= 5:
            return RetinaFace(model_file=self.onnx_file, session=session)

        # Every remaining rule reads axes 2 and 3; without them nothing matches.
        if len(input_shape) < 4:
            return None
        # presumably the spatial axes of an NCHW input -- TODO confirm
        dim2, dim3 = input_shape[2], input_shape[3]

        if dim2 == 192 and dim3 == 192:
            return Landmark(model_file=self.onnx_file, session=session)
        if dim2 == 96 and dim3 == 96:
            return Attribute(model_file=self.onnx_file, session=session)
        if len(inputs) == 2 and dim2 == 128 and dim3 == 128:
            return INSwapper(model_file=self.onnx_file, session=session)
        # A dimension may be a str or None rather than an int when it is not
        # fixed in the model; the isinstance check keeps ">=" and "%" from
        # raising TypeError so such models fall through to None.
        if dim2 == dim3 and isinstance(dim2, int) and dim2 >= 112 and dim2 % 16 == 0:
            return ArcFaceONNX(model_file=self.onnx_file, session=session)
        return None
def find_onnx_file(dir_path):
    """Return the last ``*.onnx`` path (in sorted order) directly inside ``dir_path``.

    Args:
        dir_path: Directory to search. Sub-directories are not searched.

    Returns:
        The matching path that sorts last as a string, or ``None`` when
        ``dir_path`` does not exist or contains no ``.onnx`` file.
    """
    if not os.path.exists(dir_path):
        return None
    # Escape glob metacharacters so a directory such as "pack[v1]" is matched
    # literally instead of being read as a character class.
    pattern = "%s/*.onnx" % glob.escape(str(dir_path))
    paths = glob.glob(pattern)
    if not paths:
        return None
    return max(paths)
def get_default_providers():
    """Return a new list with the default execution provider names.

    The names are listed as CUDA, then CANN, then CPU.
    """
    provider_names = (
        'CUDAExecutionProvider',
        'CANNExecutionProvider',
        'CPUExecutionProvider',
    )
    return list(provider_names)
def get_default_provider_options():
    """Return the default provider options, which is ``None`` (no options)."""
    default_options = None
    return default_options
def get_model(name, **kwargs):
    """Load a model by name or by ``.onnx`` path and return its wrapper.

    Args:
        name: Either a path ending in ``.onnx`` or the name of a directory
            under ``<root>/models`` that is searched with ``find_onnx_file``.
        **kwargs: Optional settings:
            root: Base directory, ``~`` is expanded (default ``'~/.insightface'``).
            download: When true and the ``.onnx`` path is missing, the file is
                requested through ``download_onnx`` (default ``False``).
            download_zip: Passed on to ``download_onnx`` (default ``False``).
            providers: Passed to the router (default ``get_default_providers()``).
            provider_options: Passed to the router
                (default ``get_default_provider_options()``).

    Returns:
        Whatever ``ModelRouter.get_model`` returns, or ``None`` when ``name``
        is a directory name and no ``.onnx`` file is found in it.

    Raises:
        AssertionError: If the resolved model path does not exist or is not a
            regular file.
    """
    base_dir = osp.expanduser(kwargs.get('root', '~/.insightface'))
    models_dir = osp.join(base_dir, 'models')
    may_download = kwargs.get('download', False)
    as_zip = kwargs.get('download_zip', False)

    if name.endswith('.onnx'):
        # An explicit .onnx path is used as given.
        model_file = name
    else:
        # Otherwise `name` is a directory name below <root>/models.
        model_file = find_onnx_file(osp.join(models_dir, name))
        if model_file is None:
            return None

    is_missing = not osp.exists(model_file)
    if is_missing and may_download:
        model_file = download_onnx('models', model_file, root=base_dir, download_zip=as_zip)

    assert osp.exists(model_file), 'model_file %s should exist'%model_file
    assert osp.isfile(model_file), 'model_file %s should be a file'%model_file

    router = ModelRouter(model_file)
    session_args = {
        'providers': kwargs.get('providers', get_default_providers()),
        'provider_options': kwargs.get('provider_options', get_default_provider_options()),
    }
    return router.get_model(**session_args)