11import torch
2- from enum import Enum
2+ from enum import IntEnum
33from skimage import io
4- from skimage import color
54import numpy as np
65
76from .utils import *
87
98
10- class LandmarksType(Enum ):
9+ class LandmarksType(IntEnum ):
1110 """Enum class defining the type of landmarks to detect.
1211
1312 ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
@@ -20,26 +19,32 @@ class LandmarksType(Enum):
2019 _3D = 3
2120
2221
23- class NetworkSize(Enum ):
22+ class NetworkSize(IntEnum ):
2423 # TINY = 1
2524 # SMALL = 2
2625 # MEDIUM = 3
2726 LARGE = 4
2827
29- def __new__(cls, value):
30- member = object.__new__(cls)
31- member._value_ = value
32- return member
3328
34- def __int__(self):
35- return self.value
36-
37- models_urls = {
29+ default_model_urls = {
3830 '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip',
3931 '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-4a694010b9.zip',
4032 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-6c4283c0e0.zip',
4133}
4234
35+ models_urls = {
36+ '1.6': {
37+ '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4_1.6-c827573f02.zip',
38+ '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4_1.6-ec5cf40a1d.zip',
39+ 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth_1.6-2aa3f18772.zip',
40+ },
41+ '1.5': {
42+ '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4_1.5-a60332318a.zip',
43+ '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4_1.5-176570af4d.zip',
44+ 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth_1.5-bc10f98e39.zip',
45+ },
46+ }
47+
4348
4449class FaceAlignment:
4550 def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
@@ -50,6 +55,11 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
5055 self.verbose = verbose
5156
5257 network_size = int(network_size)
58+ pytorch_version = torch.__version__
59+ if 'dev' in pytorch_version:
60+ pytorch_version = pytorch_version.rsplit('.', 2)[0]
61+ else:
62+ pytorch_version = pytorch_version.rsplit('.', 1)[0]
5363
5464 if 'cuda' in device:
5565 torch.backends.cudnn.benchmark = True
@@ -64,14 +74,16 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
6474 network_name = '2DFAN-' + str(network_size)
6575 else:
6676 network_name = '3DFAN-' + str(network_size)
67- self.face_alignment_net = torch.jit.load(load_file_from_url(models_urls[network_name]))
77+ self.face_alignment_net = torch.jit.load(
78+ load_file_from_url(models_urls.get(pytorch_version, default_model_urls)[network_name]))
6879
6980 self.face_alignment_net.to(device)
7081 self.face_alignment_net.eval()
7182
7283 # Initialiase the depth prediciton network
7384 if landmarks_type == LandmarksType._3D:
74- self.depth_prediciton_net = torch.jit.load(load_file_from_url(models_urls['depth']))
85+ self.depth_prediciton_net = torch.jit.load(
86+ load_file_from_url(models_urls.get(pytorch_version, default_model_urls)['depth']))
7587
7688 self.depth_prediciton_net.to(device)
7789 self.depth_prediciton_net.eval()
0 commit comments