-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathQualityModel.py
More file actions
83 lines (74 loc) · 2.37 KB
/
QualityModel.py
File metadata and controls
83 lines (74 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Modified version of code from
<https://github.com/atzoriandrea/ViT-FIQA-Assessing-Face-Image-Quality-using-Vision-Transformers/tree/master>.
"""
import os
# import sys
# sys.path.append('../')
# print(os.getcwd())
import torch
# import numpy as np
from .FaceModel import FaceModel
# from backbones.iresnet_qs import iresnet100, iresnet50
from .backbones.vit_qs import VisionTransformer
class QualityModel(FaceModel):
    """Face image quality model: a ViT backbone that yields an embedding
    and a per-image quality score.

    Modified from the ViT-FIQA reference implementation; only the
    ``vit_FC`` backbone path is wired in (the upstream iresnet50/100
    branches were disabled and have been removed here).
    """

    def __init__(
        self,
        model_prefix,
        model_epoch,
        gpu_id,
        backbone,
        model_mode: str,  # NOTE Added (not original).
    ):
        """Store the quality-head mode and delegate to FaceModel.

        model_mode must be 'token' or 'crfiqa'; it is consumed by
        _get_model when constructing the VisionTransformer.
        """
        # Validate and stash the mode BEFORE the base constructor runs,
        # because FaceModel.__init__ presumably triggers _get_model,
        # which reads self.model_mode.
        assert model_mode in {'token', 'crfiqa'}, model_mode
        self.model_mode = model_mode
        super().__init__(model_prefix, model_epoch, gpu_id, backbone)

    def _get_model(
        self,
        ctx,
        image_size,
        prefix,
        epoch,
        layer,
        backbone,
    ):
        """Build the backbone, load its checkpoint, and return it in eval mode.

        Parameters
        ----------
        ctx : device ordinal used for DataParallel placement.
        image_size, layer : unused here; kept for the FaceModel interface.
        prefix : directory containing the checkpoint file.
        epoch : checkpoint tag; only used on the non-"_FC" load path.
        backbone : backbone identifier; only "vit_FC" is supported.

        Returns a ``torch.nn.DataParallel``-wrapped model in eval mode.
        Raises NotImplementedError for any unsupported backbone.
        """
        if backbone == "vit_FC":
            backbones = VisionTransformer(
                img_size=112,
                patch_size=9,
                num_classes=512,
                embed_dim=512,
                depth=12,
                num_heads=8,
                drop_path_rate=0.1,
                norm_layer="ln",
                mask_ratio=0.1,
                mode=self.model_mode,
            )
        else:
            raise NotImplementedError("Error. Backbone not found!")

        if backbone in ("vit_FC", "iresnet50_FC"):
            # TODO(security): torch.load unpickles arbitrary objects; pass
            # weights_only=True once the checkpoint format allows it.
            dict_checkpoint = torch.load(
                os.path.join(prefix, "model.pt"),
                map_location=torch.device('cpu'),
            )
            backbones.load_state_dict(dict_checkpoint)
        else:
            # str(epoch) guards against an integer epoch crashing the
            # string concatenation (original did `epoch + "backbone.pth"`).
            # NOTE(review): no map_location on this path — the tensors load
            # onto whatever device the checkpoint was saved from; confirm
            # that asymmetry with the "_FC" branch is intended.
            weight = torch.load(os.path.join(prefix, str(epoch) + "backbone.pth"))
            backbones.load_state_dict(weight)

        model = torch.nn.DataParallel(backbones, device_ids=[ctx])
        model.eval()
        return model

    @torch.no_grad()
    def _getFeatureBlob(self, input_blob):
        """Run the model on a raw image blob; return (features, quality).

        Normalizes pixel values from [0, 255] to [-1, 1] in place, then
        forwards through self.model, which returns an embedding tensor and
        a quality-score tensor; both are returned as numpy arrays.
        """
        imgs = torch.Tensor(input_blob).cpu()  # .cuda()
        imgs.div_(255).sub_(0.5).div_(0.5)  # [0, 255] -> [-1, 1]
        feat, qs = self.model(imgs)
        return feat.cpu().numpy(), qs.cpu().numpy()