Skip to content

Commit f744007

Browse files
🔥 Drop utils.tensor_norm in favor of torch.linalg.norm
1 parent 4dcb651 commit f744007

File tree

7 files changed

+18
-137
lines changed

7 files changed

+18
-137
lines changed

‎piqa/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
specific image quality assessment metric.
66
"""
77

8-
__version__ = '1.1.2'
8+
__version__ = '1.1.3'
99

1010
from .tv import TV
1111
from .psnr import PSNR

‎piqa/gmsd.py‎

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
prewitt_kernel,
2626
gradient_kernel,
2727
channel_conv,
28-
tensor_norm,
2928
)
3029

3130

@@ -77,8 +76,8 @@ def gmsd(
7776
# Gradient magnitude
7877
pad = kernel.size(-1) // 2
7978

80-
gm_x = tensor_norm(channel_conv(x, kernel, padding=pad), dim=[1])
81-
gm_y = tensor_norm(channel_conv(y, kernel, padding=pad), dim=[1])
79+
gm_x = torch.linalg.norm(channel_conv(x, kernel, padding=pad), dim=1)
80+
gm_y = torch.linalg.norm(channel_conv(y, kernel, padding=pad), dim=1)
8281

8382
gm_xy = gm_x * gm_y
8483

‎piqa/lpips.py‎

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch.hub as hub
2020

2121
from piqa.utils import _jit, _assert_type, _reduce
22-
from piqa.utils.functional import normalize_tensor
2322

2423
from typing import Dict, List
2524

@@ -225,8 +224,8 @@ def forward(
225224
residuals = []
226225

227226
for lin, fx, fy in zip(self.lins, self.net(input), self.net(target)):
228-
fx = normalize_tensor(fx, dim=[1], norm='L2')
229-
fy = normalize_tensor(fy, dim=[1], norm='L2')
227+
fx = fx / torch.linalg.norm(fx, dim=1, keepdim=True)
228+
fy = fy / torch.linalg.norm(fy, dim=1, keepdim=True)
230229

231230
mse = ((fx - fy) ** 2).mean(dim=(-1, -2), keepdim=True)
232231
residuals.append(lin(mse).flatten())

‎piqa/mdsi.py‎

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
prewitt_kernel,
2020
gradient_kernel,
2121
channel_conv,
22-
tensor_norm,
2322
)
2423

2524
import piqa.utils.complex as cx
@@ -77,11 +76,11 @@ def mdsi(
7776
# Gradient magnitude
7877
pad = kernel.size(-1) // 2
7978

80-
gm_x = tensor_norm(channel_conv(l_x, kernel, padding=pad), dim=[1])
81-
gm_y = tensor_norm(channel_conv(l_y, kernel, padding=pad), dim=[1])
82-
gm_avg = tensor_norm(
79+
gm_x = torch.linalg.norm(channel_conv(l_x, kernel, padding=pad), dim=1)
80+
gm_y = torch.linalg.norm(channel_conv(l_y, kernel, padding=pad), dim=1)
81+
gm_avg = torch.linalg.norm(
8382
channel_conv((l_x + l_y) / 2., kernel, padding=pad),
84-
dim=[1],
83+
dim=1,
8584
)
8685

8786
gm_x_sq, gm_y_sq, gm_avg_sq = gm_x ** 2, gm_y ** 2, gm_avg ** 2

‎piqa/utils/__init__.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,12 @@ def _assert_type(
8585
)
8686

8787

88+
@_jit
8889
def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
89-
r"""Returns a reducing module.
90+
r"""Returns the reduction of \(x\).
9091
9192
Args:
93+
x: A tensor, \((*,)\).
9294
reduction: Specifies the reduction type:
9395
`'none'` | `'mean'` | `'sum'`.
9496

‎piqa/utils/complex.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def mod(x: torch.Tensor, squared: bool = False) -> torch.Tensor:
7070
tensor([2.0000, 1.0000])
7171
"""
7272

73-
x = (x ** 2).sum(dim=-1)
73+
x = x.square().sum(dim=-1)
7474

7575
if not squared:
7676
x = torch.sqrt(x)

‎piqa/utils/functional.py‎

Lines changed: 5 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,13 @@ def gaussian_kernel(
8282
) -> torch.Tensor:
8383
r"""Returns the 1-dimensional Gaussian kernel of size \(K\).
8484
85-
$$ G(x) = \frac{1}{\sum_{y = 1}^{K} G(y)} \exp
85+
$$ G(x) = \gamma \exp
8686
\left(-\frac{(x - \mu)^2}{2 \sigma^2}\right) $$
8787
88-
where \(x \in [1; K]\) is a position in the kernel
88+
where \(\gamma\) is such that
89+
90+
$$ \sum_{x = 1}^{K} G(x) = 1 $$
91+
8992
and \(\mu = \frac{1 + K}{2}\).
9093
9194
Args:
@@ -263,124 +266,3 @@ def gradient_kernel(kernel: torch.Tensor) -> torch.Tensor:
263266
"""
264267

265268
return torch.stack([kernel, kernel.t()]).unsqueeze(1)
266-
267-
268-
def tensor_norm(
269-
x: torch.Tensor,
270-
dim: List[int], # Union[int, Tuple[int, ...]] = ()
271-
keepdim: bool = False,
272-
norm: str = 'L2',
273-
) -> torch.Tensor:
274-
r"""Returns the norm of \(x\).
275-
276-
$$ L_1(x) = \left\| x \right\|_1 = \sum_i \left| x_i \right| $$
277-
278-
$$ L_2(x) = \left\| x \right\|_2 = \left( \sum_i x^2_i \right)^\frac{1}{2} $$
279-
280-
Args:
281-
x: A tensor, \((*,)\).
282-
dim: The dimension(s) along which to calculate the norm.
283-
keepdim: Whether the output tensor has `dim` retained or not.
284-
norm: Specifies the norm function to apply:
285-
`'L1'` | `'L2'` | `'L2_squared'`.
286-
287-
Wikipedia:
288-
https://en.wikipedia.org/wiki/Norm_(mathematics)
289-
290-
Example:
291-
>>> x = torch.arange(9).float().view(3, 3)
292-
>>> x
293-
tensor([[0., 1., 2.],
294-
[3., 4., 5.],
295-
[6., 7., 8.]])
296-
>>> tensor_norm(x, dim=0)
297-
tensor([6.7082, 8.1240, 9.6437])
298-
"""
299-
300-
if norm == 'L1':
301-
x = x.abs()
302-
else: # norm in ['L2', 'L2_squared']
303-
x = x ** 2
304-
305-
x = x.sum(dim=dim, keepdim=keepdim)
306-
307-
if norm == 'L2':
308-
x = x.sqrt()
309-
310-
return x
311-
312-
313-
def normalize_tensor(
314-
x: torch.Tensor,
315-
dim: List[int], # Union[int, Tuple[int, ...]] = ()
316-
norm: str = 'L2',
317-
epsilon: float = 1e-8,
318-
) -> torch.Tensor:
319-
r"""Returns \(x\) normalized.
320-
321-
$$ \hat{x} = \frac{x}{\left\|x\right\|} $$
322-
323-
Args:
324-
x: A tensor, \((*,)\).
325-
dim: The dimension(s) along which to normalize.
326-
norm: Specifies the norm function to use:
327-
`'L1'` | `'L2'` | `'L2_squared'`.
328-
epsilon: A numerical stability term.
329-
330-
Returns:
331-
The normalized tensor, \((*,)\).
332-
333-
Example:
334-
>>> x = torch.arange(9, dtype=torch.float).view(3, 3)
335-
>>> x
336-
tensor([[0., 1., 2.],
337-
[3., 4., 5.],
338-
[6., 7., 8.]])
339-
>>> normalize_tensor(x, dim=0)
340-
tensor([[0.0000, 0.1231, 0.2074],
341-
[0.4472, 0.4924, 0.5185],
342-
[0.8944, 0.8616, 0.8296]])
343-
"""
344-
345-
norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm)
346-
347-
return x / (norm + epsilon)
348-
349-
350-
def unravel_index(
351-
indices: torch.LongTensor,
352-
shape: List[int],
353-
) -> torch.LongTensor:
354-
r"""Converts flat indices into unraveled coordinates in a target shape.
355-
356-
This is a `torch` implementation of `numpy.unravel_index`.
357-
358-
Args:
359-
indices: A tensor of (flat) indices, \((*, N)\).
360-
shape: The targeted shape, \((D,)\).
361-
362-
Returns:
363-
The unraveled coordinates, \((*, N, D)\).
364-
365-
Example:
366-
>>> unravel_index(torch.arange(9), shape=(3, 3))
367-
tensor([[0, 0],
368-
[0, 1],
369-
[0, 2],
370-
[1, 0],
371-
[1, 1],
372-
[1, 2],
373-
[2, 0],
374-
[2, 1],
375-
[2, 2]])
376-
"""
377-
378-
coord = []
379-
380-
for dim in reversed(shape):
381-
coord.append(indices % dim)
382-
indices = indices // dim
383-
384-
coord = torch.stack(coord[::-1], dim=-1)
385-
386-
return coord

0 commit comments

Comments
 (0)