-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathadaface.py
More file actions
87 lines (72 loc) · 3.34 KB
/
adaface.py
File metadata and controls
87 lines (72 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Adapter for the CVLface model from
<https://github.com/mk-minchul/CVLface?tab=readme-ov-file#-state-of-the-art-performance>
with
- "Arch: ViT KPRPE"
- "Loss: AdaFace"
- (Training) "Dataset: WebFace12M"
linked at/provided via
<https://huggingface.co/minchul/cvlface_adaface_vit_base_kprpe_webface12m>.
The corresponding face detector / facial landmark detector / alignment model at
<https://huggingface.co/minchul/cvlface_DFA_mobilenet>
is not required by the CvlfaceFrAdapterMode.PURE_FIQAT_PREP
(which is currently the only mode included here).
Please take a look at the linked sites regarding the corresponding papers and license information.
The Python package dependencies (in addition to the FIQAT setup)
can be installed via the cvlface_requirements.txt file,
but the same requirement specifications are already included
in the current fiqat_ts_pa_ext pyproject.toml as well
(regarding e.g. "pip install --editable .").
The model files will be downloaded automatically via the
huggingface_download
function, but the destination path has to be specified via either
- the CvlfaceFrAdapter cvlface_dir __init__ parameter,
- or by adding a corresponding entry to the FIQAT .toml configuration file, e.g.:
[cvlface]
models = "my/local/cvlface/model/directory/path"
"""
from transformers import AutoModel
from huggingface_hub import hf_hub_download
import shutil
import os
import sys
# helper function to download huggingface repo and use model
def download(repo_id, path, HF_TOKEN=None):
    """Download every file of the Hugging Face repo *repo_id* into *path*.

    The repo is expected to publish a ``files.txt`` manifest listing its files;
    that manifest is fetched first (if not already cached locally), then each
    listed file plus the fixed extras (``config.json``, ``wrapper.py``,
    ``model.safetensors``) is downloaded unless it already exists on disk.

    :param repo_id: Hugging Face repo identifier, e.g. ``'user/model'``.
    :param path: Local destination directory (created if missing).
    :param HF_TOKEN: Optional Hugging Face access token for gated repos.
    """
    os.makedirs(path, exist_ok=True)
    files_path = os.path.join(path, 'files.txt')
    if not os.path.exists(files_path):
        # NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
        # huggingface_hub releases; kept for compatibility with older versions.
        hf_hub_download(repo_id, 'files.txt', token=HF_TOKEN, local_dir=path, local_dir_use_symlinks=False)
    # Reuse the already-computed files_path instead of re-joining the path.
    with open(files_path, 'r') as f:
        files = f.read().split('\n')
    for file in [f for f in files if f] + ['config.json', 'wrapper.py', 'model.safetensors']:
        full_path = os.path.join(path, file)
        if not os.path.exists(full_path):
            hf_hub_download(repo_id, file, token=HF_TOKEN, local_dir=path, local_dir_use_symlinks=False)
# helper function to download huggingface repo and use model
def load_model_from_local_path(path, HF_TOKEN=None):
    """Load a ``trust_remote_code`` model from an already-downloaded local repo.

    The repo's custom code expects to be loaded with the repo directory both as
    the current working directory and on ``sys.path``; both changes are made
    temporarily and are now restored even if loading raises (the original code
    leaked the chdir and the sys.path entry on failure).

    :param path: Local directory containing the downloaded repo.
    :param HF_TOKEN: Optional Hugging Face access token.
    :return: The loaded ``transformers`` model instance.
    """
    cwd = os.getcwd()
    os.chdir(path)
    sys.path.insert(0, path)
    try:
        model = AutoModel.from_pretrained(path, trust_remote_code=True, token=HF_TOKEN)
    finally:
        # Restore process-global state even when from_pretrained raises.
        os.chdir(cwd)
        # Pops index 0, matching the insert above — assumes the remote code did
        # not itself prepend to sys.path in between (same assumption as before).
        sys.path.pop(0)
    return model
# helper function to download huggingface repo and use model
def load_model_by_repo_id(repo_id, save_path, HF_TOKEN=None, force_download=False):
    """Ensure *repo_id* is present at *save_path* and load the model from it.

    :param repo_id: Hugging Face repo identifier.
    :param save_path: Local cache directory for the repo files.
    :param HF_TOKEN: Optional Hugging Face access token.
    :param force_download: When True, wipe any existing local copy first.
    :return: The loaded model.
    """
    # A forced refresh simply deletes the cached copy before re-downloading.
    if force_download and os.path.exists(save_path):
        shutil.rmtree(save_path)
    download(repo_id, save_path, HF_TOKEN)
    return load_model_from_local_path(save_path, HF_TOKEN)
if __name__ == '__main__':
    # Demo: download the model (if needed) and run a single image through it.
    HF_TOKEN = '---YOUR TOKEN---'
    repo_id = 'minchul/cvlface_adaface_ir101_webface12m'
    path = os.path.expanduser('~/.cvlface_cache/minchul/cvlface_adaface_ir101_webface12m')
    model = load_model_by_repo_id(repo_id, path, HF_TOKEN)

    # The model expects a normalized RGB image tensor as input.
    from torchvision.transforms import Compose, ToTensor, Normalize
    from PIL import Image

    image = Image.open('---Path to test image---')
    preprocess = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    batch = preprocess(image).unsqueeze(0)  # torch.randn(1, 3, 112, 112)
    out = model(batch)