Skip to content

Commit 85f114f

Browse files
📝 Prevent irrelevant docstrings
1 parent c9def95 commit 85f114f

File tree

7 files changed: +31 lines added, -0 lines removed

spiq/gmsd.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class GMSD(nn.Module):
8787
"""
8888

8989
def __init__(self, reduction: str = 'mean', **kwargs):
90+
r""""""
9091
super().__init__()
9192

9293
self.reduce = build_reduce(reduction)
@@ -97,6 +98,9 @@ def forward(
9798
input: torch.Tensor,
9899
target: torch.Tensor,
99100
) -> torch.Tensor:
101+
r"""Defines the computation performed at every call.
102+
"""
103+
100104
l = gmsd(input, target, **self.kwargs)
101105

102106
return self.reduce(l)

spiq/lpips.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
reduction: str = 'mean',
4949
trainable: bool = False,
5050
):
51+
r""""""
5152
super().__init__()
5253

5354
# ImageNet scaling
@@ -97,6 +98,9 @@ def forward(
9798
input: torch.Tensor,
9899
target: torch.Tensor,
99100
) -> torch.Tensor:
101+
r"""Defines the computation performed at every call.
102+
"""
103+
100104
if self.scaling:
101105
input = (input - self.shift) / self.scale
102106
target = (target - self.shift) / self.scale

spiq/mdsi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ class MDSI(nn.Module):
125125
"""
126126

127127
def __init__(self, reduction: str = 'mean', **kwargs):
128+
r""""""
128129
super().__init__()
129130

130131
self.reduce = build_reduce(reduction)
@@ -135,6 +136,9 @@ def forward(
135136
input: torch.Tensor,
136137
target: torch.Tensor,
137138
) -> torch.Tensor:
139+
r"""Defines the computation performed at every call.
140+
"""
141+
138142
l = mdsi(input, target, **self.kwargs)
139143

140144
return self.reduce(l)

spiq/psnr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class PSNR(nn.Module):
5555
"""
5656

5757
def __init__(self, reduction: str = 'mean', **kwargs):
58+
r""""""
5859
super().__init__()
5960

6061
self.reduce = build_reduce(reduction)
@@ -68,6 +69,9 @@ def forward(
6869
input: torch.Tensor,
6970
target: torch.Tensor,
7071
) -> torch.Tensor:
72+
r"""Defines the computation performed at every call.
73+
"""
74+
7175
l = psnr(
7276
input.unsqueeze(1).flatten(1),
7377
target.unsqueeze(1).flatten(1),

spiq/ssim.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def __init__(
198198
reduction: str = 'mean',
199199
**kwargs,
200200
):
201+
r""""""
201202
super().__init__()
202203

203204
self.register_buffer('window', create_window(window_size, n_channels))
@@ -210,6 +211,9 @@ def forward(
210211
input: torch.Tensor,
211212
target: torch.Tensor,
212213
) -> torch.Tensor:
214+
r"""Defines the computation performed at every call.
215+
"""
216+
213217
l = ssim_per_channel(
214218
input,
215219
target,
@@ -240,6 +244,9 @@ def forward(
240244
input: torch.Tensor,
241245
target: torch.Tensor,
242246
) -> torch.Tensor:
247+
r"""Defines the computation performed at every call.
248+
"""
249+
243250
l = msssim_per_channel(
244251
input,
245252
target,

spiq/tv.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,16 @@ class TV(nn.Module):
4646
"""
4747

4848
def __init__(self, reduction: str = 'mean', **kwargs):
49+
r""""""
4950
super().__init__()
5051

5152
self.reduce = build_reduce(reduction)
5253
self.kwargs = kwargs
5354

5455
def forward(self, input: torch.Tensor) -> torch.Tensor:
56+
r"""Defines the computation performed at every call.
57+
"""
58+
5559
l = tv(input, **self.kwargs)
5660

5761
return self.reduce(l)

spiq/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,16 @@ class Intermediary(nn.Module):
180180
"""
181181

182182
def __init__(self, layers: nn.Sequential, targets: List[int]):
183+
r""""""
183184
super().__init__()
184185

185186
self.layers = layers
186187
self.targets = set(targets)
187188

188189
def forward(self, input: torch.Tensor) -> List[torch.Tensor]:
190+
r"""Defines the computation performed at every call.
191+
"""
192+
189193
output = []
190194

191195
for i, layer in enumerate(self.layers):

Comments (0) — this commit has no comments.