-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmetrics.py
More file actions
115 lines (104 loc) · 3.73 KB
/
metrics.py
File metadata and controls
115 lines (104 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from skimage import metrics
import torch
import torch.hub
from lpips.lpips import LPIPS
import os
import os
import numpy as np
photometric = {
"mse": None,
"ssim": None,
"psnr": None,
"lpips": None
}
import pyiqa
import torch.nn.functional as F
short_edge = 384
NR_metrics = {}
def compute_img_metric_single(img,metric = "niqe"):
# metric:niqe,brisque
if metric not in NR_metrics.keys():
iqa_metric = pyiqa.create_metric(metric, device=torch.device("cuda"))
NR_metrics.update({metric:iqa_metric})
# resize
if metric == 'liqe_mix':
h,w = img.shape[2],img.shape[3]
if h < w:
new_h, new_w = short_edge, int(w * short_edge / h)
else:
new_h, new_w = int(h * short_edge / w), short_edge
img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)
return NR_metrics[metric](img).item()
def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,
metric="mse", margin=0, mask=None):
"""
im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1)
"""
if metric not in photometric.keys():
raise RuntimeError(f"img_utils:: metric {metric} not recognized")
if photometric[metric] is None:
if metric == "mse":
photometric[metric] = metrics.mean_squared_error
elif metric == "ssim":
photometric[metric] = metrics.structural_similarity
elif metric == "psnr":
photometric[metric] = metrics.peak_signal_noise_ratio
elif metric == "lpips":
photometric[metric] = LPIPS().cpu()
if mask is not None:
if mask.dim() == 3:
mask = mask.unsqueeze(1)
if mask.shape[1] == 1:
mask = mask.expand(-1, 3, -1, -1)
mask = mask.permute(0, 2, 3, 1).numpy()
batchsz, hei, wid, _ = mask.shape
if margin > 0:
marginh = int(hei * margin) + 1
marginw = int(wid * margin) + 1
mask = mask[:, marginh:hei - marginh, marginw:wid - marginw]
# convert from [0, 1] to [-1, 1]
im1t = (im1t * 2 - 1).clamp(-1, 1)
im2t = (im2t * 2 - 1).clamp(-1, 1)
if im1t.dim() == 3:
im1t = im1t.unsqueeze(0)
im2t = im2t.unsqueeze(0)
im1t = im1t.detach().cpu()
im2t = im2t.detach().cpu()
if im1t.shape[-1] == 3:
im1t = im1t.permute(0, 3, 1, 2)
im2t = im2t.permute(0, 3, 1, 2)
im1 = im1t.permute(0, 2, 3, 1).numpy()
im2 = im2t.permute(0, 2, 3, 1).numpy()
batchsz, hei, wid, _ = im1.shape
if margin > 0:
marginh = int(hei * margin) + 1
marginw = int(wid * margin) + 1
im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw]
im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw]
values = []
for i in range(batchsz):
if metric in ["mse", "psnr"]:
if mask is not None:
im1 = im1 * mask[i]
im2 = im2 * mask[i]
value = photometric[metric](
im1[i], im2[i]
)
if mask is not None:
hei, wid, _ = im1[i].shape
pixelnum = mask[i, ..., 0].sum()
value = value - 10 * np.log10(hei * wid / pixelnum)
elif metric in ["ssim"]:
value, ssimmap = photometric["ssim"](
im1[i], im2[i], channel_axis=-1, data_range=2, full=True
)
if mask is not None:
value = (ssimmap * mask[i]).sum() / mask[i].sum()
elif metric in ["lpips"]:
value = photometric[metric](
im1t[i:i + 1], im2t[i:i + 1]
)[0,0,0,0]
else:
raise NotImplementedError
values.append(value)
return sum(values) / len(values)