Skip to content

Commit 48ec8c4

Browse files
🩹 Fix performance issues due to torch.linalg.norm
1 parent 61cbb4c commit 48ec8c4

File tree

8 files changed

+60
-22
lines changed

8 files changed

+60
-22
lines changed

‎piqa/fsim.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
filter_grid,
2929
log_gabor,
3030
channel_conv,
31+
l2_norm,
3132
)
3233

3334

@@ -87,8 +88,8 @@ def fsim(
8788
# Gradient magnitude similarity
8889
pad = kernel.size(-1) // 2
8990

90-
g_x = torch.linalg.norm(channel_conv(y_x, kernel, padding=pad), dim=1)
91-
g_y = torch.linalg.norm(channel_conv(y_y, kernel, padding=pad), dim=1)
91+
g_x = l2_norm(channel_conv(y_x, kernel, padding=pad), dims=[1])
92+
g_y = l2_norm(channel_conv(y_y, kernel, padding=pad), dims=[1])
9293

9394
s_g = (2 * g_x * g_y + t2) / (g_x ** 2 + g_y ** 2 + t2)
9495

‎piqa/gmsd.py‎

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
prewitt_kernel,
2525
gradient_kernel,
2626
channel_conv,
27+
l2_norm,
2728
)
2829

2930

@@ -75,8 +76,8 @@ def gmsd(
7576
# Gradient magnitude
7677
pad = kernel.size(-1) // 2
7778

78-
gm_x = torch.linalg.norm(channel_conv(x, kernel, padding=pad), dim=1)
79-
gm_y = torch.linalg.norm(channel_conv(y, kernel, padding=pad), dim=1)
79+
gm_x = l2_norm(channel_conv(x, kernel, padding=pad), dims=[1])
80+
gm_y = l2_norm(channel_conv(y, kernel, padding=pad), dims=[1])
8081

8182
gm_xy = gm_x * gm_y
8283

@@ -109,7 +110,8 @@ def ms_gmsd(
109110
without color space conversion.
110111
111112
.. math::
112-
\text{MS-GMSD}(x, y) = \sum^{M}_{i = 1} w_i \text{GMSD}(x^i, y^i)
113+
\text{MS-GMSD}(x, y) =
114+
\sqrt{\sum^{M}_{i = 1} w_i \text{GMSD}(x^i, y^i)^2}
113115
114116
where :math:`x^i` and :math:`y^i` are obtained by downsampling
115117
the initial tensors by a factor :math:`2^{i - 1}`.
@@ -150,8 +152,8 @@ def ms_gmsd(
150152
c=c, alpha=alpha,
151153
))
152154

153-
msgmsd = torch.stack(gmsds, dim=-1) ** 2
154-
msgmsd = torch.sqrt((msgmsd * weights).sum(dim=-1))
155+
msgmsd = weights * torch.stack(gmsds, dim=-1) ** 2
156+
msgmsd = msgmsd.sum(dim=-1).sqrt()
155157

156158
return msgmsd
157159

‎piqa/lpips.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Dict, List
2323

2424
from .utils import _jit, assert_type, reduce_tensor
25+
from .utils.functional import l2_norm
2526

2627

2728
ORIGIN: str = 'https://github.com/richzhang/PerceptualSimilarity'
@@ -209,8 +210,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
209210
residuals = []
210211

211212
for lin, fx, fy in zip(self.lins, self.net(input), self.net(target)):
212-
fx = fx / torch.linalg.norm(fx, dim=1, keepdim=True)
213-
fy = fy / torch.linalg.norm(fy, dim=1, keepdim=True)
213+
fx = fx / l2_norm(fx, dims=[1], keepdim=True)
214+
fy = fy / l2_norm(fy, dims=[1], keepdim=True)
214215

215216
mse = ((fx - fy) ** 2).mean(dim=(-1, -2), keepdim=True)
216217
residuals.append(lin(mse).flatten())

‎piqa/mdsi.py‎

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
prewitt_kernel,
2323
gradient_kernel,
2424
channel_conv,
25+
l2_norm,
2526
)
2627

2728

@@ -78,12 +79,9 @@ def mdsi(
7879
# Gradient magnitude
7980
pad = kernel.size(-1) // 2
8081

81-
gm_x = torch.linalg.norm(channel_conv(l_x, kernel, padding=pad), dim=1)
82-
gm_y = torch.linalg.norm(channel_conv(l_y, kernel, padding=pad), dim=1)
83-
gm_avg = torch.linalg.norm(
84-
channel_conv((l_x + l_y) / 2., kernel, padding=pad),
85-
dim=1,
86-
)
82+
gm_x = l2_norm(channel_conv(l_x, kernel, padding=pad), dims=[1])
83+
gm_y = l2_norm(channel_conv(l_y, kernel, padding=pad), dims=[1])
84+
gm_avg = l2_norm(channel_conv(l_x + l_y, kernel, padding=pad), dims=[1]) / 2
8785

8886
gm_x_sq, gm_y_sq, gm_avg_sq = gm_x ** 2, gm_y ** 2, gm_avg ** 2
8987

‎piqa/ssim.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,10 @@ def ms_ssim(
187187

188188
css.append(torch.relu(cs) if i + 1 < m else torch.relu(ss))
189189

190-
msss = torch.stack(css, dim=-1)
191-
msss = (msss ** weights).prod(dim=-1)
190+
msss = torch.stack(css, dim=-1) ** weights
191+
msss = msss.prod(dim=-1).mean(dim=-1)
192192

193-
return msss.mean(dim=-1)
193+
return msss
194194

195195

196196
class SSIM(nn.Module):

‎piqa/utils/functional.py‎

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,38 @@ def log_gabor(f: Tensor, f_0: float, sigma_f: float) -> Tensor:
335335
"""
336336

337337
return torch.exp(- (f / f_0).log() ** 2 / (2 * sigma_f ** 2))
338+
339+
340+
def l2_norm(
341+
x: torch.Tensor,
342+
dims: List[int],
343+
keepdim: bool = False,
344+
) -> torch.Tensor:
345+
r"""Returns the :math:`L_2` norm of :math:`x`.
346+
347+
.. math::
348+
L_2(x) = \left\| x \right\|_2 = \sqrt{\sum_i x^2_i}
349+
350+
Args:
351+
x: A tensor, :math:`(*,)`.
352+
dims: The dimensions along which to calculate the norm.
353+
keepdim: Whether the output tensor has `dims` retained or not.
354+
355+
Wikipedia:
356+
https://en.wikipedia.org/wiki/Norm_(mathematics)
357+
358+
Example:
359+
>>> x = torch.arange(9).float().view(3, 3)
360+
>>> x
361+
tensor([[0., 1., 2.],
362+
[3., 4., 5.],
363+
[6., 7., 8.]])
364+
>>> l2_norm(x, dims=[0])
365+
tensor([6.7082, 8.1240, 9.6437])
366+
"""
367+
368+
x = x ** 2
369+
x = x.sum(dim=dims, keepdim=keepdim)
370+
x = x.sqrt()
371+
372+
return x

‎piqa/vsi.py‎

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
filter_grid,
3131
log_gabor,
3232
channel_conv,
33+
l2_norm,
3334
)
3435

3536

@@ -87,8 +88,8 @@ def vsi(
8788
# Gradient magnitude similarity
8889
pad = kernel.size(-1) // 2
8990

90-
g_x = torch.linalg.norm(channel_conv(l_x, kernel, padding=pad), dim=1)
91-
g_y = torch.linalg.norm(channel_conv(l_y, kernel, padding=pad), dim=1)
91+
g_x = l2_norm(channel_conv(l_x, kernel, padding=pad), dims=[1])
92+
g_y = l2_norm(channel_conv(l_y, kernel, padding=pad), dims=[1])
9293

9394
s_g = (2 * g_x * g_y + c2) / (g_x ** 2 + g_y ** 2 + c2)
9495

@@ -171,7 +172,7 @@ def sdsp(
171172
x_f = fft.ifft2(fft.fft2(x_lab) * filtr)
172173
x_f = cx.real(torch.view_as_real(x_f))
173174

174-
s_f = torch.linalg.norm(x_f, dim=1)
175+
s_f = l2_norm(x_f, dims=[1])
175176

176177
# Color prior
177178
x_ab = x_lab[:, 1:]

‎setup.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setuptools.setup(
1212
name='piqa',
13-
version='1.2.0',
13+
version='1.2.1',
1414
packages=setuptools.find_packages(),
1515
description='PyTorch Image Quality Assessment',
1616
keywords='image quality processing metrics torch vision',

0 commit comments

Comments
 (0)