Skip to content

Commit d098f4d

Browse files
🐛 Fix JITing of complex submodule
The `complex` function confused the compiler.
1 parent af274e6 commit d098f4d

File tree

6 files changed

+12
-15
lines changed

6 files changed

+12
-15
lines changed

piqa/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
specific image quality assessement metric.
66
"""
77

8-
__version__ = '1.1.6'
8+
__version__ = '1.1.7'
99

1010
from .tv import TV
1111
from .psnr import PSNR

piqa/fsim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def fsim(
105105
s_q = (2 * q_x * q_y + t4) / (q_x ** 2 + q_y ** 2 + t4)
106106

107107
s_iq = s_i * s_q
108-
s_iq = cx.complex(s_iq, torch.zeros_like(s_iq))
108+
s_iq = cx.complx(s_iq, torch.zeros_like(s_iq))
109109
s_iq_lambda = cx.real(cx.pow(s_iq, lmbda))
110110

111111
s_l = s_l * s_iq_lambda

piqa/mdsi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def mdsi(
101101
cs = cs_num / cs_den
102102

103103
# Gradient-chromaticity similarity
104-
gs = cx.complex(gs, torch.zeros_like(gs))
105-
cs = cx.complex(cs, torch.zeros_like(cs))
104+
gs = cx.complx(gs, torch.zeros_like(gs))
105+
cs = cx.complx(cs, torch.zeros_like(cs))
106106

107107
if combination == 'prod':
108108
gcs = cx.prod(cx.pow(gs, gamma), cx.pow(cs, beta))

piqa/utils/complex.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66

7-
def complex(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
7+
def complx(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
88
r"""Returns a complex tensor with its real part equal to \(\Re\) and
99
its imaginary part equal to \(\Im\).
1010
@@ -20,7 +20,7 @@ def complex(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
2020
Example:
2121
>>> x = torch.tensor([2., 0.7071])
2222
>>> y = torch.tensor([0., 0.7071])
23-
>>> complex(x, y)
23+
>>> complx(x, y)
2424
tensor([[2.0000, 0.0000],
2525
[0.7071, 0.7071]])
2626
"""
@@ -103,7 +103,7 @@ def turn(x: torch.Tensor) -> torch.Tensor:
103103
[-0.7071, 0.7071]])
104104
"""
105105

106-
return complex(-imag(x), real(x))
106+
return complx(-imag(x), real(x))
107107

108108

109109
def polar(r: torch.Tensor, phi: torch.Tensor) -> torch.Tensor:
@@ -127,7 +127,7 @@ def polar(r: torch.Tensor, phi: torch.Tensor) -> torch.Tensor:
127127
[0.7071, 0.7071]])
128128
"""
129129

130-
return complex(r * torch.cos(phi), r * torch.sin(phi))
130+
return complx(r * torch.cos(phi), r * torch.sin(phi))
131131

132132

133133
def mod(x: torch.Tensor, squared: bool = False) -> torch.Tensor:
@@ -200,7 +200,7 @@ def prod(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
200200
x_r, x_i = x[..., 0], x[..., 1]
201201
y_r, y_i = y[..., 0], y[..., 1]
202202

203-
return complex(x_r * y_r - x_i * y_i, x_i * y_r + x_r * y_i)
203+
return complx(x_r * y_r - x_i * y_i, x_i * y_r + x_r * y_i)
204204

205205

206206
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

piqa/vsi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def vsi(
101101
s_c = (2 * mn_x * mn_y + c3) / (mn_x ** 2 + mn_y ** 2 + c3)
102102
s_c = s_c.prod(dim=1)
103103

104-
s_c = cx.complex(s_c, torch.zeros_like(s_c))
104+
s_c = cx.complx(s_c, torch.zeros_like(s_c))
105105
s_c_beta = cx.real(cx.pow(s_c, beta))
106106

107107
s_vs = s_vs * s_c_beta

tests/benchmark.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
'PSNR': (2, {
4545
'sk.psnr-np': sk.peak_signal_noise_ratio,
4646
'piq.psnr': piq.psnr,
47-
'kornia.PSNR': kornia.PSNRLoss(max_val=1.),
47+
'kornia.PSNR': kornia.PSNRLoss(1.),
4848
'piqa.PSNR': piqa.PSNR(),
4949
}),
5050
'SSIM': (2, {
@@ -55,10 +55,7 @@
5555
gaussian_weights=True,
5656
),
5757
'piq.ssim': lambda x, y: piq.ssim(x, y, downsample=False),
58-
'kornia.SSIM-halfloss': kornia.SSIM(
59-
window_size=11,
60-
reduction='mean',
61-
),
58+
'kornia.SSIM-halfloss': kornia.SSIMLoss(11),
6259
'IQA.SSIM-loss': IQA.SSIM(),
6360
'vainf.SSIM': vainf.SSIM(data_range=1.),
6461
'piqa.SSIM': piqa.SSIM(),

0 commit comments

Comments
 (0)