Skip to content

Commit ffc07e0

Browse files
🚸 Provide option to disable debugging (#17)
🚸 Improve type error messages
1 parent d098f4d commit ffc07e0

File tree

13 files changed

+68
-56
lines changed

13 files changed

+68
-56
lines changed

‎README.md‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,10 @@ If you need the absolute best performances, the assertions can be disabled with
129129
```bash
130130
python -O your_awesome_code_using_piqa.py
131131
```
132+
133+
Alternatively, you can disable PIQA's type assertions within your code with
134+
135+
```python
136+
from piqa.utils import set_debug
137+
set_debug(False)
138+
```

‎piqa/fsim.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import torch.nn as nn
2222
import torch.nn.functional as F
2323

24-
from piqa.utils import _jit, _assert_type, _reduce
24+
from piqa.utils import _jit, assert_type, reduce_tensor
2525
from piqa.utils.color import ColorConv
2626
from piqa.utils.functional import (
2727
scharr_kernel,
@@ -308,7 +308,7 @@ def forward(
308308
r"""Defines the computation performed at every call.
309309
"""
310310

311-
_assert_type(
311+
assert_type(
312312
[input, target],
313313
device=self.kernel.device,
314314
dim_range=(4, 4),
@@ -339,4 +339,4 @@ def forward(
339339
# FSIM
340340
l = fsim(input, target, pc_input, pc_target, kernel=self.kernel, **self.kwargs)
341341

342-
return _reduce(l, self.reduction)
342+
return reduce_tensor(l, self.reduction)

‎piqa/gmsd.py‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import torch.nn as nn
2323
import torch.nn.functional as F
2424

25-
from piqa.utils import _jit, _assert_type, _reduce
25+
from piqa.utils import _jit, assert_type, reduce_tensor
2626
from piqa.utils.color import ColorConv
2727
from piqa.utils.functional import (
2828
prewitt_kernel,
@@ -219,7 +219,7 @@ def forward(
219219
r"""Defines the computation performed at every call.
220220
"""
221221

222-
_assert_type(
222+
assert_type(
223223
[input, target],
224224
device=self.kernel.device,
225225
dim_range=(4, 4),
@@ -239,7 +239,7 @@ def forward(
239239
# GMSD
240240
l = gmsd(input, target, kernel=self.kernel, **self.kwargs)
241241

242-
return _reduce(l, self.reduction)
242+
return reduce_tensor(l, self.reduction)
243243

244244

245245
class MS_GMSD(nn.Module):
@@ -310,7 +310,7 @@ def forward(
310310
r"""Defines the computation performed at every call.
311311
"""
312312

313-
_assert_type(
313+
assert_type(
314314
[input, target],
315315
device=self.kernel.device,
316316
dim_range=(4, 4),
@@ -331,4 +331,4 @@ def forward(
331331
**self.kwargs,
332332
)
333333

334-
return _reduce(l, self.reduction)
334+
return reduce_tensor(l, self.reduction)

‎piqa/haarpsi.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch.nn as nn
2020
import torch.nn.functional as F
2121

22-
from piqa.utils import _jit, _assert_type, _reduce
22+
from piqa.utils import _jit, assert_type, reduce_tensor
2323
from piqa.utils.color import ColorConv
2424
from piqa.utils.functional import (
2525
haar_kernel,
@@ -171,7 +171,7 @@ def forward(
171171
r"""Defines the computation performed at every call.
172172
"""
173173

174-
_assert_type(
174+
assert_type(
175175
[input, target],
176176
device=self.convert.device,
177177
dim_range=(4, 4),
@@ -191,4 +191,4 @@ def forward(
191191
# HaarPSI
192192
l = haarpsi(input, target, **self.kwargs)
193193

194-
return _reduce(l, self.reduction)
194+
return reduce_tensor(l, self.reduction)

‎piqa/lpips.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torchvision.models as models
1919
import torch.hub as hub
2020

21-
from piqa.utils import _jit, _assert_type, _reduce
21+
from piqa.utils import _jit, assert_type, reduce_tensor
2222

2323
from typing import Dict, List
2424

@@ -207,7 +207,7 @@ def forward(
207207
r"""Defines the computation performed at every call.
208208
"""
209209

210-
_assert_type(
210+
assert_type(
211211
[input, target],
212212
device=self.shift.device,
213213
dim_range=(4, 4),
@@ -232,4 +232,4 @@ def forward(
232232

233233
l = torch.stack(residuals, dim=-1).sum(dim=-1)
234234

235-
return _reduce(l, self.reduction)
235+
return reduce_tensor(l, self.reduction)

‎piqa/mdsi.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch.nn as nn
1717
import torch.nn.functional as F
1818

19-
from piqa.utils import _jit, _assert_type, _reduce
19+
from piqa.utils import _jit, assert_type, reduce_tensor
2020
from piqa.utils.color import ColorConv
2121
from piqa.utils.functional import (
2222
prewitt_kernel,
@@ -178,7 +178,7 @@ def forward(
178178
r"""Defines the computation performed at every call.
179179
"""
180180

181-
_assert_type(
181+
assert_type(
182182
[input, target],
183183
device=self.kernel.device,
184184
dim_range=(4, 4),
@@ -202,4 +202,4 @@ def forward(
202202
# MDSI
203203
l = mdsi(input, target, kernel=self.kernel, **self.kwargs)
204204

205-
return _reduce(l, self.reduction)
205+
return reduce_tensor(l, self.reduction)

‎piqa/psnr.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111

12-
from piqa.utils import _jit, _assert_type, _reduce
12+
from piqa.utils import _jit, assert_type, reduce_tensor
1313

1414

1515
@_jit
@@ -109,7 +109,7 @@ def forward(
109109
r"""Defines the computation performed at every call.
110110
"""
111111

112-
_assert_type(
112+
assert_type(
113113
[input, target],
114114
device=input.device,
115115
dim_range=(1, -1),
@@ -118,4 +118,4 @@ def forward(
118118

119119
l = psnr(input, target, **self.kwargs)
120120

121-
return _reduce(l, self.reduction)
121+
return reduce_tensor(l, self.reduction)

‎piqa/ssim.py‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import torch.nn as nn
2323
import torch.nn.functional as F
2424

25-
from piqa.utils import _jit, _assert_type, _reduce
25+
from piqa.utils import _jit, assert_type, reduce_tensor
2626
from piqa.utils.functional import (
2727
gaussian_kernel,
2828
kernel_views,
@@ -251,7 +251,7 @@ def forward(
251251
r"""Defines the computation performed at every call.
252252
"""
253253

254-
_assert_type(
254+
assert_type(
255255
[input, target],
256256
device=self.kernel.device,
257257
dim_range=(3, -1),
@@ -261,7 +261,7 @@ def forward(
261261

262262
l = ssim(input, target, kernel=self.kernel, **self.kwargs)[0]
263263

264-
return _reduce(l, self.reduction)
264+
return reduce_tensor(l, self.reduction)
265265

266266

267267
class MS_SSIM(nn.Module):
@@ -331,7 +331,7 @@ def forward(
331331
r"""Defines the computation performed at every call.
332332
"""
333333

334-
_assert_type(
334+
assert_type(
335335
[input, target],
336336
device=self.kernel.device,
337337
dim_range=(4, 4),
@@ -347,4 +347,4 @@ def forward(
347347
**self.kwargs,
348348
)
349349

350-
return _reduce(l, self.reduction)
350+
return reduce_tensor(l, self.reduction)

‎piqa/tv.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111

12-
from piqa.utils import _jit, _assert_type, _reduce
12+
from piqa.utils import _jit, assert_type, reduce_tensor
1313

1414

1515
@_jit
@@ -94,8 +94,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
9494
r"""Defines the computation performed at every call.
9595
"""
9696

97-
_assert_type([input], device=input.device, dim_range=(3, -1))
97+
assert_type([input], device=input.device, dim_range=(3, -1))
9898

9999
l = tv(input, **self.kwargs)
100100

101-
return _reduce(l, self.reduction)
101+
return reduce_tensor(l, self.reduction)

‎piqa/utils/__init__.py‎

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,25 @@
1313
_jit = lambda f: f
1414

1515

16-
def _debug(mode: bool = __debug__) -> bool:
17-
r"""Returns whether debugging is enabled or not.
16+
__piqa_debug__ = __debug__
17+
18+
def set_debug(mode: bool = False) -> bool:
19+
r"""Sets and returns whether debugging is enabled or not.
20+
If `__debug__` is `False`, this function has no effect.
21+
22+
Example:
23+
>>> set_debug(False)
24+
False
1825
"""
1926

20-
return mode
27+
global __piqa_debug__
28+
29+
__piqa_debug__ = __debug__ and mode
30+
31+
return __piqa_debug__
2132

2233

23-
def _assert_type(
34+
def assert_type(
2435
tensors: List[torch.Tensor],
2536
device: torch.device,
2637
dim_range: Tuple[int, int] = (0, -1),
@@ -33,60 +44,60 @@ def _assert_type(
3344
Example:
3445
>>> x = torch.rand(5, 3, 256, 256)
3546
>>> y = torch.rand(5, 3, 256, 256)
36-
>>> _assert_type([x, y], device=x.device, dim_range=(4, 4), n_channels=3)
47+
>>> assert_type([x, y], device=x.device, dim_range=(4, 4), n_channels=3)
3748
"""
3849

39-
if not _debug():
50+
if not __piqa_debug__:
4051
return
4152

4253
ref = tensors[0]
4354

4455
for t in tensors:
4556
assert t.device == device, (
46-
f'Expected tensors to be on {device}, got {t.device}'
57+
f'Tensors expected to be on {device}, got {t.device}'
4758
)
4859

4960
assert t.shape == ref.shape, (
50-
'Expected tensors to be of the same shape, got'
61+
'Tensors expected to be of the same shape, got'
5162
f' {ref.shape} and {t.shape}'
5263
)
5364

5465
if dim_range[0] == dim_range[1]:
5566
assert t.dim() == dim_range[0], (
56-
'Expected number of dimensions to be'
67+
'Number of dimensions expected to be'
5768
f' {dim_range[0]}, got {t.dim()}'
5869
)
5970
elif dim_range[0] < dim_range[1]:
6071
assert dim_range[0] <= t.dim() <= dim_range[1], (
61-
'Expected number of dimensions to be between'
72+
'Number of dimensions expected to be between'
6273
f' {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
6374
)
6475
elif dim_range[0] > 0:
6576
assert dim_range[0] <= t.dim(), (
66-
'Expected number of dimensions to be greater or equal to'
77+
'Number of dimensions expected to be greater or equal to'
6778
f' {dim_range[0]}, got {t.dim()}'
6879
)
6980

7081
if n_channels > 0:
7182
assert t.size(1) == n_channels, (
72-
'Expected number of channels to be'
83+
'Number of channels expected to be'
7384
f' {n_channels}, got {t.size(1)}'
7485
)
7586

7687
if value_range[0] < value_range[1]:
7788
assert value_range[0] <= t.min(), (
78-
'Expected values to be greater or equal to'
89+
'Values expected to be greater or equal to'
7990
f' {value_range[0]}, got {t.min()}'
8091
)
8192

8293
assert t.max() <= value_range[1], (
83-
'Expected values to be lower or equal to'
94+
'Values expected to be lower or equal to'
8495
f' {value_range[1]}, got {t.max()}'
8596
)
8697

8798

8899
@_jit
89-
def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
100+
def reduce_tensor(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
90101
r"""Returns the reduction of \(x\).
91102
92103
Args:
@@ -96,7 +107,7 @@ def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
96107
97108
Example:
98109
>>> x = torch.arange(5)
99-
>>> _reduce(x, reduction='sum')
110+
>>> reduce_tensor(x, reduction='sum')
100111
tensor(10)
101112
"""
102113

0 commit comments

Comments
 (0)