Skip to content

Commit fcc306a

Browse files
committed
fix type
1 parent 7e6602a commit fcc306a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

face_alignment/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None):
113113

114114
landmarks = []
115115
for i, d in enumerate(detected_faces):
116-
center = torch.FloatTensor(
116+
center = torch.tensor(
117117
[d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
118118
center[1] = center[1] - (d[3] - d[1]) * 0.12
119119
scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
@@ -130,9 +130,9 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None):
130130
out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
131131
out = out.cpu().numpy()
132132

133-
pts, pts_img = get_preds_fromhm(out, center, scale)
134-
pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)
133+
pts, pts_img = get_preds_fromhm(out, center.numpy(), scale)
135134
pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
135+
pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)
136136

137137
if self.landmarks_type == LandmarksType._3D:
138138
heatmaps = np.zeros((68, 256, 256), dtype=np.float32)

0 commit comments

Comments
 (0)