-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path__init__.py
More file actions
70 lines (60 loc) · 2.18 KB
/
__init__.py
File metadata and controls
70 lines (60 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Implementation/adapter for "ViT-FIQA" based on the corresponding repository, i.e.:
<https://github.com/atzoriandrea/ViT-FIQA-Assessing-Face-Image-Quality-using-Vision-Transformers/tree/master>
"""
# - # Standard imports:
from pathlib import Path
# - # External imports:
# None
# - # Toolkit imports:
import fiqat
# - # Local imports:
from .QualityModel import QualityModel
class VitFiqaAdapter:
    """Adapter that runs a ViT-FIQA quality model over lists of image paths.

    The model variant is derived from ``model_path``: its stem with all spaces
    removed must be ``'ViT-FIQA(T)'`` (``model_mode='token'``) or
    ``'ViT-FIQA(C)'`` (``model_mode='crfiqa'``). The underlying
    ``QualityModel`` is only constructed on first use (see ``load_model``).
    """

    # Normalized model name (path stem without spaces) -> ``model_mode``
    # value forwarded to ``QualityModel``.
    _MODEL_MODES = {
        'ViT-FIQA(T)': 'token',
        'ViT-FIQA(C)': 'crfiqa',
    }
    # Accepted values for ``color_channel``.
    _COLOR_CHANNELS = ('BGR', 'RGB')

    def __init__(self, model_path: Path, color_channel: str = 'RGB'):
        """Configure the adapter; the model itself is not loaded here.

        Args:
            model_path: Path whose stem names the model variant, e.g.
                ``.../ViT-FIQA (T)``. It is also forwarded to ``QualityModel``
                as ``model_prefix``. A ``str`` is accepted and converted to a
                ``Path``.
            color_channel: ``'RGB'`` or ``'BGR'``; forwarded to
                ``QualityModel.get_batch_feature`` as ``color``.

        Raises:
            ValueError: If the model variant derived from ``model_path`` or the
                ``color_channel`` is not supported. (Explicit exceptions are
                used instead of ``assert`` so validation survives ``python -O``.)
        """
        model_path = Path(model_path)
        model_type = model_path.stem.replace(' ', '')
        if model_type not in self._MODEL_MODES:
            raise ValueError(
                f"Unsupported ViT-FIQA model type {model_type!r} "
                f"(derived from {str(model_path)!r}); "
                f"expected one of {sorted(self._MODEL_MODES)}"
            )
        if color_channel not in self._COLOR_CHANNELS:
            raise ValueError(
                f"Unsupported color_channel {color_channel!r}; "
                f"expected one of {list(self._COLOR_CHANNELS)}"
            )
        self.model_type = model_type
        self.model_mode = self._MODEL_MODES[model_type]
        # - #
        self.model_path = model_path
        # Forwarded to QualityModel; what None means there is not visible here.
        self.model_epoch = None
        # Forwarded to QualityModel. NOTE(review): presumably None = CPU and an
        # int (e.g. 0) = CUDA device index; confirm against QualityModel.
        self.gpu_id = None
        self.backbone = 'vit_FC'
        self.color_channel = color_channel
        # Mini-batch size used inside QualityModel.get_batch_feature.
        self.batch_size = 16
        # Number of paths handed to each get_batch_feature call, so every call
        # covers several mini-batches.
        self.process_batch_size = self.batch_size * 4
        # - #
        self.quality_model = None

    def load_model(self):
        """Construct the underlying ``QualityModel`` once; later calls are no-ops."""
        if self.quality_model is not None:
            return
        self.quality_model = QualityModel(
            model_prefix=self.model_path,
            model_epoch=self.model_epoch,
            gpu_id=self.gpu_id,
            backbone=self.backbone,
            model_mode=self.model_mode,
        )

    def _process(self, image_path_list: list):
        """Run the model on one chunk of image paths.

        Loads the model on demand and returns the
        ``(embedding_batch, quality_batch)`` pair produced by
        ``QualityModel.get_batch_feature``.
        """
        self.load_model()
        embedding_batch, quality_batch = self.quality_model.get_batch_feature(
            image_path_list=image_path_list,
            batch_size=self.batch_size,
            color=self.color_channel,
        )
        return embedding_batch, quality_batch

    def process(self, image_path_list: list, output_with_embedding: bool = False):
        """Lazily yield one quality score per image path, in input order.

        Paths are handed to the model in chunks of ``process_batch_size``;
        nothing runs until the returned generator is iterated.

        Args:
            image_path_list: Image paths to assess.
            output_with_embedding: If True, yield ``(embedding, quality_score)``
                pairs instead of bare scores.

        Yields:
            ``quality_score`` (the single element of each per-image quality
            output), or ``(embedding, quality_score)`` when
            ``output_with_embedding`` is True.

        Raises:
            RuntimeError: If the model returns a different number of embeddings
                or quality values than paths in a chunk. ``zip`` would otherwise
                truncate silently and misalign scores with the input paths.
        """
        for chunk in fiqat.iterate_as_batches(image_path_list, self.process_batch_size):
            paths = list(chunk)
            embedding_batch, quality_batch = self._process(paths)
            if not len(embedding_batch) == len(quality_batch) == len(paths):
                raise RuntimeError(
                    f"Model output size mismatch: {len(paths)} paths, "
                    f"{len(embedding_batch)} embeddings, "
                    f"{len(quality_batch)} quality values"
                )
            for embedding, quality in zip(embedding_batch, quality_batch):
                # Internal sanity check: each per-image quality output is
                # expected to hold exactly one value.
                assert quality.shape == (1,), quality.shape
                quality_score = quality[0]
                if output_with_embedding:
                    yield embedding, quality_score
                else:
                    yield quality_score