Skip to content

Commit c047411

Browse files
committed
expose the dtype
1 parent c19affb commit c047411

File tree

3 files changed

+19
-39
lines changed

3 files changed

+19
-39
lines changed

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,14 @@ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, face_detec
7373
In order to specify the device (GPU or CPU) on which the code will run one can explicitly pass the device flag:
7474

7575
```python
76+
import torch
7677
import face_alignment
7778

78-
# cuda for CUDA
79+
# cuda for CUDA, mps for Apple M1/2 GPUs.
7980
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device='cpu')
81+
82+
# running using lower precision
83+
fa = fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, dtype=torch.bfloat16, device='cuda')
8084
```
8185

8286
Please also see the ``examples`` folder
@@ -85,10 +89,10 @@ Please also see the ``examples`` folder
8589

8690
```python
8791

88-
# dlib
92+
# dlib (fast, may miss faces)
8993
model = FaceAlignment(landmarks_type= LandmarksType.TWO_D, face_detector='dlib')
9094

91-
# SFD
95+
# SFD (likely best results, but slowest)
9296
model = FaceAlignment(landmarks_type= LandmarksType.TWO_D, face_detector='sfd')
9397

9498
# Blazeface (front camera model)

face_alignment/api.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ class NetworkSize(IntEnum):
5050

5151
class FaceAlignment:
5252
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
53-
device='cuda', flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
53+
device='cuda', dtype=torch.float32, flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
5454
self.device = device
5555
self.flip_input = flip_input
5656
self.landmarks_type = landmarks_type
5757
self.verbose = verbose
58+
self.dtype = dtype
5859

5960
if version.parse(torch.__version__) < version.parse('1.5.0'):
6061
raise ImportError(f'Unsupported pytorch version detected. Minimum supported version of pytorch: 1.5.0\
@@ -84,15 +85,15 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
8485
self.face_alignment_net = torch.jit.load(
8586
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)[network_name]))
8687

87-
self.face_alignment_net.to(device)
88+
self.face_alignment_net.to(device, dtype=dtype)
8889
self.face_alignment_net.eval()
8990

9091
# Initialiase the depth prediciton network
9192
if landmarks_type == LandmarksType.THREE_D:
9293
self.depth_prediciton_net = torch.jit.load(
9394
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)['depth']))
9495

95-
self.depth_prediciton_net.to(device)
96+
self.depth_prediciton_net.to(device, dtype=dtype)
9697
self.depth_prediciton_net.eval()
9798

9899
def get_landmarks(self, image_or_path, detected_faces=None, return_bboxes=False, return_landmark_score=False):
@@ -159,13 +160,13 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
159160
inp = torch.from_numpy(inp.transpose(
160161
(2, 0, 1))).float()
161162

162-
inp = inp.to(self.device)
163+
inp = inp.to(self.device, dtype=self.dtype)
163164
inp.div_(255.0).unsqueeze_(0)
164165

165166
out = self.face_alignment_net(inp).detach()
166167
if self.flip_input:
167168
out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
168-
out = out.cpu().numpy()
169+
out = out.to(device='cpu', dtype=torch.float32).numpy()
169170

170171
pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
171172
pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
@@ -181,9 +182,9 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
181182
heatmaps = torch.from_numpy(
182183
heatmaps).unsqueeze_(0)
183184

184-
heatmaps = heatmaps.to(self.device)
185+
heatmaps = heatmaps.to(self.device, dtype=self.dtype)
185186
depth_pred = self.depth_prediciton_net(
186-
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
187+
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1).to(dtype=torch.float32)
187188
pts_img = torch.cat(
188189
(pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)
189190

face_alignment/detection/blazeface/net_blazeface.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -114,38 +114,13 @@ def _define_back_model_layers(self):
114114
self.backbone = nn.Sequential(
115115
nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=2, padding=0, bias=True),
116116
nn.ReLU(inplace=True),
117-
118-
BlazeBlock(24, 24),
119-
BlazeBlock(24, 24),
120-
BlazeBlock(24, 24),
121-
BlazeBlock(24, 24),
122-
BlazeBlock(24, 24),
123-
BlazeBlock(24, 24),
124-
BlazeBlock(24, 24),
117+
*[BlazeBlock(24, 24) for _ in range(7)],
125118
BlazeBlock(24, 24, stride=2),
126-
BlazeBlock(24, 24),
127-
BlazeBlock(24, 24),
128-
BlazeBlock(24, 24),
129-
BlazeBlock(24, 24),
130-
BlazeBlock(24, 24),
131-
BlazeBlock(24, 24),
132-
BlazeBlock(24, 24),
119+
*[BlazeBlock(24, 24) for _ in range(7)],
133120
BlazeBlock(24, 48, stride=2),
134-
BlazeBlock(48, 48),
135-
BlazeBlock(48, 48),
136-
BlazeBlock(48, 48),
137-
BlazeBlock(48, 48),
138-
BlazeBlock(48, 48),
139-
BlazeBlock(48, 48),
140-
BlazeBlock(48, 48),
121+
*[BlazeBlock(48, 48) for _ in range(7)],
141122
BlazeBlock(48, 96, stride=2),
142-
BlazeBlock(96, 96),
143-
BlazeBlock(96, 96),
144-
BlazeBlock(96, 96),
145-
BlazeBlock(96, 96),
146-
BlazeBlock(96, 96),
147-
BlazeBlock(96, 96),
148-
BlazeBlock(96, 96),
123+
*[BlazeBlock(96, 96) for _ in range(7)],
149124
)
150125
self.final = FinalBlazeBlock(96)
151126
self.classifier_8 = nn.Conv2d(96, 2, 1, bias=True)

0 commit comments

Comments
 (0)