File tree Expand file tree Collapse file tree 2 files changed +7
-3
lines changed
Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Original file line number Diff line number Diff line change 55specific image quality assessement metric.
66"""
77
8- __version__ = '1.0.6 '
8+ __version__ = '1.0.7 '
Original file line number Diff line number Diff line change @@ -75,6 +75,7 @@ class LPIPS(nn.Module):
7575 be scaled w.r.t. ImageNet.
7676 dropout: Whether dropout is used or not.
7777 pretrained: Whether the official pretrained weights are used or not.
78+ eval: Whether to initialize the object in evaluation mode or not.
7879 reduction: Specifies the reduction to apply to the output:
7980 `'none'` | `'mean'` | `'sum'`.
8081
@@ -84,8 +85,7 @@ class LPIPS(nn.Module):
8485 * Output: (N,) or (1,) depending on `reduction`
8586
8687 Note:
87- `LPIPS` is a *trainable* metric. To prevent the weights from updating,
88- use the `torch.no_grad()` context or freeze the weights.
88+ `LPIPS` is a *trainable* metric.
8989
9090 Example:
9191 >>> criterion = LPIPS().cuda()
@@ -102,6 +102,7 @@ def __init__(
102102 scaling : bool = True ,
103103 dropout : bool = True ,
104104 pretrained : bool = True ,
105+ eval : bool = True ,
105106 reduction : str = 'mean' ,
106107 ):
107108 r""""""
@@ -143,6 +144,9 @@ def __init__(
143144 if pretrained :
144145 self .lin .load_state_dict (get_weights (network = network ))
145146
147+ if eval :
148+ self .eval ()
149+
146150 self .reduce = build_reduce (reduction )
147151
148152 def forward (
You can’t perform that action at this time.
0 commit comments