Skip to content

Commit 43ae9d7

Browse files
🐛 Fix inaccurate LPIPS
The inaccuracies came from the dropout layers. Initializing LPIPS in evaluation mode corrected this behavior.
1 parent 3227b26 commit 43ae9d7

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

piqa/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
specific image quality assessement metric.
66
"""
77

8-
__version__ = '1.0.6'
8+
__version__ = '1.0.7'

piqa/lpips.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class LPIPS(nn.Module):
7575
be scaled w.r.t. ImageNet.
7676
dropout: Whether dropout is used or not.
7777
pretrained: Whether the official pretrained weights are used or not.
78+
eval: Whether to initialize the object in evaluation mode or not.
7879
reduction: Specifies the reduction to apply to the output:
7980
`'none'` | `'mean'` | `'sum'`.
8081
@@ -84,8 +85,7 @@ class LPIPS(nn.Module):
8485
* Output: (N,) or (1,) depending on `reduction`
8586
8687
Note:
87-
`LPIPS` is a *trainable* metric. To prevent the weights from updating,
88-
use the `torch.no_grad()` context or freeze the weights.
88+
`LPIPS` is a *trainable* metric.
8989
9090
Example:
9191
>>> criterion = LPIPS().cuda()
@@ -102,6 +102,7 @@ def __init__(
102102
scaling: bool = True,
103103
dropout: bool = True,
104104
pretrained: bool = True,
105+
eval: bool = True,
105106
reduction: str = 'mean',
106107
):
107108
r""""""
@@ -143,6 +144,9 @@ def __init__(
143144
if pretrained:
144145
self.lin.load_state_dict(get_weights(network=network))
145146

147+
if eval:
148+
self.eval()
149+
146150
self.reduce = build_reduce(reduction)
147151

148152
def forward(

0 commit comments

Comments
 (0)