Skip to content

Commit 7a56439

Browse files
👽️ API changes in PyTorch 1.11 (#25)
1 parent 48ec8c4 commit 7a56439

File tree

4 files changed

+15
-4
lines changed

4 files changed

+15
-4
lines changed

piqa/ssim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
244244
assert_type(
245245
input, target,
246246
device=self.kernel.device,
247-
dim_range=(3, -1),
247+
dim_range=(3, 5),
248248
n_channels=self.kernel.size(0),
249249
value_range=(0., self.value_range),
250250
)

piqa/utils/color.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def color_conv(
2929
weight: A weight kernel, :math:`(C', C)`.
3030
"""
3131

32-
return F.conv1d(x, weight.view(weight.shape + (1,) * spatial(x)))
32+
return F.linear(x.transpose(1, -1), weight).transpose(1, -1)
3333

3434

3535
RGB_TO_YIQ = torch.tensor([

piqa/utils/functional.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,18 @@ def channel_conv(
3636
[144., 153., 162.]]]])
3737
"""
3838

39-
return F.conv1d(x, kernel, padding=padding, groups=x.size(1))
39+
D = len(kernel.shape) - 2
40+
41+
assert D <= 3, "PyTorch only supports 1D, 2D or 3D convolutions."
42+
43+
if D == 3:
44+
return F.conv3d(x, kernel, padding=padding, groups=x.size(-4))
45+
elif D == 2:
46+
return F.conv2d(x, kernel, padding=padding, groups=x.size(-3))
47+
elif D == 1:
48+
return F.conv1d(x, kernel, padding=padding, groups=x.size(-2))
49+
else:
50+
return F.linear(x, kernel.expand(x.size(-1)))
4051

4152

4253
def channel_convs(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setuptools.setup(
1212
name='piqa',
13-
version='1.2.1',
13+
version='1.2.2',
1414
packages=setuptools.find_packages(),
1515
description='PyTorch Image Quality Assessment',
1616
keywords='image quality processing metrics torch vision',

0 commit comments

Comments
 (0)