@@ -50,11 +50,12 @@ class NetworkSize(IntEnum):

 class FaceAlignment:
     def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
-                 device='cuda', flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
+                 device='cuda', dtype=torch.float32, flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
         self.device = device
         self.flip_input = flip_input
         self.landmarks_type = landmarks_type
         self.verbose = verbose
+        self.dtype = dtype

         if version.parse(torch.__version__) < version.parse('1.5.0'):
             raise ImportError(f'Unsupported pytorch version detected. Minimum supported version of pytorch: 1.5.0\
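
The new dtype argument defaults to torch.float32, so existing callers keep the previous full-precision behaviour; passing torch.float16 runs both networks in half precision. A minimal usage sketch, assuming the usual package-level import and a placeholder image path:

import torch
import face_alignment

# Half-precision inference; dtype=torch.float32 (the default) preserves the old behaviour.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.THREE_D,
                                  device='cuda', dtype=torch.float16)
preds = fa.get_landmarks('path/to/image.jpg')  # placeholder path
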
@@ -84,15 +85,15 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
         self.face_alignment_net = torch.jit.load(
             load_file_from_url(models_urls.get(pytorch_version, default_model_urls)[network_name]))

-        self.face_alignment_net.to(device)
+        self.face_alignment_net.to(device, dtype=dtype)
         self.face_alignment_net.eval()

         # Initialiase the depth prediciton network
         if landmarks_type == LandmarksType.THREE_D:
             self.depth_prediciton_net = torch.jit.load(
                 load_file_from_url(models_urls.get(pytorch_version, default_model_urls)['depth']))

-            self.depth_prediciton_net.to(device)
+            self.depth_prediciton_net.to(device, dtype=dtype)
             self.depth_prediciton_net.eval()

     def get_landmarks(self, image_or_path, detected_faces=None, return_bboxes=False, return_landmark_score=False):
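
Module.to(device, dtype=...) casts a network's floating-point parameters and buffers in place, so the jit-loaded models above run entirely in the requested precision. A standalone illustration of the call pattern, with a plain nn.Linear standing in for the scripted networks:

import torch

net = torch.nn.Linear(4, 2)            # stand-in for the jit-loaded network
net.to('cpu', dtype=torch.float16)     # same call pattern as above; casts weights and biases
print(net.weight.dtype)                # torch.float16
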
9899 def get_landmarks (self , image_or_path , detected_faces = None , return_bboxes = False , return_landmark_score = False ):
@@ -159,13 +160,13 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
             inp = torch.from_numpy(inp.transpose(
                 (2, 0, 1))).float()

-            inp = inp.to(self.device)
+            inp = inp.to(self.device, dtype=self.dtype)
             inp.div_(255.0).unsqueeze_(0)

             out = self.face_alignment_net(inp).detach()
             if self.flip_input:
                 out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
-            out = out.cpu().numpy()
+            out = out.to(device='cpu', dtype=torch.float32).numpy()

             pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
             pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
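
Here the input is cast to the model's dtype before the forward pass, and the output heatmaps are promoted back to float32 before .numpy(), so get_preds_fromhm and the rest of the post-processing stay in full precision. A standalone sketch of this cast-in/cast-out pattern, with an elementwise op standing in for the network:

import torch

dtype = torch.float16
inp = torch.rand(1, 3, 256, 256).to('cpu', dtype=dtype)    # cast the input to the model's dtype
out = (inp * 2.0).detach()                                  # stand-in for the forward pass, stays float16
out = out.to(device='cpu', dtype=torch.float32).numpy()     # back to float32 before numpy post-processing
print(out.dtype)                                            # float32
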
@@ -181,9 +182,9 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
                 heatmaps = torch.from_numpy(
                     heatmaps).unsqueeze_(0)

-                heatmaps = heatmaps.to(self.device)
+                heatmaps = heatmaps.to(self.device, dtype=self.dtype)
                 depth_pred = self.depth_prediciton_net(
-                    torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
+                    torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1).to(dtype=torch.float32)
                 pts_img = torch.cat(
                     (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)

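Casting the heatmaps to self.dtype keeps the torch.cat result in the precision the depth network expects, and the depth prediction is converted back to float32 so it can be combined with the float32 landmark coordinates. A standalone sketch, with simple slicing standing in for the depth network:

import torch

dtype = torch.float16
inp = torch.rand(1, 3, 64, 64).to(dtype=dtype)              # network input, already in self.dtype
heatmaps = torch.rand(1, 68, 64, 64)                        # heatmaps are built in float32
stacked = torch.cat((inp, heatmaps.to(dtype=dtype)), 1)     # cast so the depth net sees one dtype
depth_pred = stacked[:, 3:, 0, 0].reshape(68, 1)            # stand-in for the depth network output
depth_pred = depth_pred.to(dtype=torch.float32)             # back to float32 for the final landmarks
print(stacked.dtype, depth_pred.dtype)                      # torch.float16 torch.float32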