Skip to content

Commit bfb6dda

Browse files
committed
Fix #241
1 parent c99feba commit bfb6dda

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

face_alignment/api.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import torch
2-
from enum import Enum
2+
from enum import IntEnum
33
from skimage import io
4-
from skimage import color
54
import numpy as np
65

76
from .utils import *
87

98

10-
class LandmarksType(Enum):
9+
class LandmarksType(IntEnum):
1110
"""Enum class defining the type of landmarks to detect.
1211

1312
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
@@ -20,26 +19,32 @@ class LandmarksType(Enum):
2019
_3D = 3
2120

2221

23-
class NetworkSize(Enum):
22+
class NetworkSize(IntEnum):
2423
# TINY = 1
2524
# SMALL = 2
2625
# MEDIUM = 3
2726
LARGE = 4
2827

29-
def __new__(cls, value):
30-
member = object.__new__(cls)
31-
member._value_ = value
32-
return member
3328

34-
def __int__(self):
35-
return self.value
36-
37-
models_urls = {
29+
default_model_urls = {
3830
'2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip',
3931
'3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-4a694010b9.zip',
4032
'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-6c4283c0e0.zip',
4133
}
4234

35+
models_urls = {
36+
'1.6': {
37+
'2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4_1.6-c827573f02.zip',
38+
'3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4_1.6-ec5cf40a1d.zip',
39+
'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth_1.6-2aa3f18772.zip',
40+
},
41+
'1.5': {
42+
'2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4_1.5-a60332318a.zip',
43+
'3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4_1.5-176570af4d.zip',
44+
'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth_1.5-bc10f98e39.zip',
45+
},
46+
}
47+
4348

4449
class FaceAlignment:
4550
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
@@ -50,6 +55,11 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
5055
self.verbose = verbose
5156

5257
network_size = int(network_size)
58+
pytorch_version = torch.__version__
59+
if 'dev' in pytorch_version:
60+
pytorch_version = pytorch_version.rsplit('.', 2)[0]
61+
else:
62+
pytorch_version = pytorch_version.rsplit('.', 1)[0]
5363

5464
if 'cuda' in device:
5565
torch.backends.cudnn.benchmark = True
@@ -64,14 +74,16 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
6474
network_name = '2DFAN-' + str(network_size)
6575
else:
6676
network_name = '3DFAN-' + str(network_size)
67-
self.face_alignment_net = torch.jit.load(load_file_from_url(models_urls[network_name]))
77+
self.face_alignment_net = torch.jit.load(
78+
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)[network_name]))
6879

6980
self.face_alignment_net.to(device)
7081
self.face_alignment_net.eval()
7182

7283
# Initialiase the depth prediciton network
7384
if landmarks_type == LandmarksType._3D:
74-
self.depth_prediciton_net = torch.jit.load(load_file_from_url(models_urls['depth']))
85+
self.depth_prediciton_net = torch.jit.load(
86+
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)['depth']))
7587

7688
self.depth_prediciton_net.to(device)
7789
self.depth_prediciton_net.eval()

0 commit comments

Comments
 (0)