forked from zllrunning/face-parsing.PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
97 lines (77 loc) · 3.35 KB
/
evaluate.py
File metadata and controls
97 lines (77 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import logging
from logger import setup_logger
from model import BiSeNet
import torch
import os
import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
    """Render a label map as a colour image and optionally write it to disk.

    Label 0 is drawn white; every other label ``k`` is drawn with entry ``k``
    of the built-in 24-colour palette.

    Args:
        im: Source image. Currently unused: blending the mask over the image
            is disabled, so only the colour mask is produced. The parameter
            is kept so existing callers keep working.
        parsing_anno: 2-D integer array of per-pixel labels; it is cast to
            ``uint8`` before use.
        stride: Scale factor handed to ``cv2.resize`` as both ``fx`` and
            ``fy`` (nearest-neighbour, so labels are never interpolated).
        save_im: When true, write the colour mask to ``save_path``.
        save_path: Destination file for the colour mask.

    Returns:
        The colour mask as a ``uint8`` array of shape ``(H, W, 3)``.

    Raises:
        IndexError: If a label is larger than the last palette index.
    """
    # 24-entry palette, indexed by label.
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]
    palette = np.array(part_colors, dtype=np.uint8)
    # Label 0 is never painted with its palette colour; it stays white.
    palette[0] = 255

    labels = np.asarray(parsing_anno).astype(np.uint8)
    labels = cv2.resize(
        labels, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)

    # One fancy-indexing lookup replaces the per-class np.where loop.
    vis_im = palette[labels]

    if save_im:
        out_dir = osp.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # cv2.imwrite reports failure through its return value, not an
        # exception, so surface it instead of silently losing the result.
        written = cv2.imwrite(
            save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
        if not written:
            logging.warning("Failed to write parsing map to %s", save_path)

    return vis_im
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    """Run the parsing network over a directory and save one colour mask per image.

    Every entry of ``dspth`` (in sorted name order) is opened, converted to
    RGB, resized to 512x512 and passed through a ``BiSeNet``. The argmax of
    the first network output is rendered by ``vis_parsing_maps`` and written
    to ``respth`` under the same file name. Entries whose output file already
    exists are skipped.

    Args:
        respth: Output directory; created if it does not exist.
        dspth: Directory containing the input images.
        cp: Checkpoint file name, resolved relative to ``res/cp``.
    """
    os.makedirs(respth, exist_ok=True)

    logging.info("Initiating Model")
    n_classes = 19
    # Use the GPU when there is one, otherwise fall back to the CPU so the
    # script is not tied to CUDA hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = BiSeNet(n_classes=n_classes)
    net.to(device)

    save_pth = osp.join('res/cp', cp)
    logging.info("Start Loading Model")
    # map_location makes loading independent of the device the checkpoint
    # was saved from.
    net.load_state_dict(torch.load(save_pth, map_location=device))
    logging.info("Finished Model Loading")
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    with torch.no_grad():
        for image_path in sorted(os.listdir(dspth)):
            out_path = osp.join(respth, image_path)
            if os.path.exists(out_path):
                # Already processed; skip it.
                continue

            logging.info(image_path)
            with Image.open(osp.join(dspth, image_path)) as img:
                # Normalize above is defined for exactly three channels, so
                # grayscale / RGBA / palette inputs must be converted first.
                image = img.convert('RGB').resize((512, 512), Image.BILINEAR)

            batch = torch.unsqueeze(to_tensor(image), 0).to(device)
            out = net(batch)[0]
            # Drop the batch dimension, then take the argmax over axis 0
            # (presumably the class dimension - confirm against BiSeNet).
            parsing = out.squeeze(0).cpu().numpy().argmax(0)
            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=out_path)
if __name__ == "__main__":
    # Script entry point: set up logging, then run the evaluation with the
    # hard-coded input directory, output directory and checkpoint name.
    setup_logger('./res')
    source_dir = '/mnt/lustre/wangzhibo/ffhq/images1024x1024'
    target_dir = '/mnt/lustre/wangzhibo/ffhq/images_mask1024x1024'
    evaluate(respth=target_dir, dspth=source_dir, cp='79999_iter.pth')