File tree Expand file tree Collapse file tree 7 files changed +31
-0
lines changed
Expand file tree Collapse file tree 7 files changed +31
-0
lines changed Original file line number Diff line number Diff line change @@ -87,6 +87,7 @@ class GMSD(nn.Module):
8787 """
8888
8989 def __init__ (self , reduction : str = 'mean' , ** kwargs ):
90+ r""""""
9091 super ().__init__ ()
9192
9293 self .reduce = build_reduce (reduction )
@@ -97,6 +98,9 @@ def forward(
9798 input : torch .Tensor ,
9899 target : torch .Tensor ,
99100 ) -> torch .Tensor :
101+ r"""Defines the computation performed at every call.
102+ """
103+
100104 l = gmsd (input , target , ** self .kwargs )
101105
102106 return self .reduce (l )
Original file line number Diff line number Diff line change @@ -48,6 +48,7 @@ def __init__(
4848 reduction : str = 'mean' ,
4949 trainable : bool = False ,
5050 ):
51+ r""""""
5152 super ().__init__ ()
5253
5354 # ImageNet scaling
@@ -97,6 +98,9 @@ def forward(
9798 input : torch .Tensor ,
9899 target : torch .Tensor ,
99100 ) -> torch .Tensor :
101+ r"""Defines the computation performed at every call.
102+ """
103+
100104 if self .scaling :
101105 input = (input - self .shift ) / self .scale
102106 target = (target - self .shift ) / self .scale
Original file line number Diff line number Diff line change @@ -125,6 +125,7 @@ class MDSI(nn.Module):
125125 """
126126
127127 def __init__ (self , reduction : str = 'mean' , ** kwargs ):
128+ r""""""
128129 super ().__init__ ()
129130
130131 self .reduce = build_reduce (reduction )
@@ -135,6 +136,9 @@ def forward(
135136 input : torch .Tensor ,
136137 target : torch .Tensor ,
137138 ) -> torch .Tensor :
139+ r"""Defines the computation performed at every call.
140+ """
141+
138142 l = mdsi (input , target , ** self .kwargs )
139143
140144 return self .reduce (l )
Original file line number Diff line number Diff line change @@ -55,6 +55,7 @@ class PSNR(nn.Module):
5555 """
5656
5757 def __init__ (self , reduction : str = 'mean' , ** kwargs ):
58+ r""""""
5859 super ().__init__ ()
5960
6061 self .reduce = build_reduce (reduction )
@@ -68,6 +69,9 @@ def forward(
6869 input : torch .Tensor ,
6970 target : torch .Tensor ,
7071 ) -> torch .Tensor :
72+ r"""Defines the computation performed at every call.
73+ """
74+
7175 l = psnr (
7276 input .unsqueeze (1 ).flatten (1 ),
7377 target .unsqueeze (1 ).flatten (1 ),
Original file line number Diff line number Diff line change @@ -198,6 +198,7 @@ def __init__(
198198 reduction : str = 'mean' ,
199199 ** kwargs ,
200200 ):
201+ r""""""
201202 super ().__init__ ()
202203
203204 self .register_buffer ('window' , create_window (window_size , n_channels ))
@@ -210,6 +211,9 @@ def forward(
210211 input : torch .Tensor ,
211212 target : torch .Tensor ,
212213 ) -> torch .Tensor :
214+ r"""Defines the computation performed at every call.
215+ """
216+
213217 l = ssim_per_channel (
214218 input ,
215219 target ,
@@ -240,6 +244,9 @@ def forward(
240244 input : torch .Tensor ,
241245 target : torch .Tensor ,
242246 ) -> torch .Tensor :
247+ r"""Defines the computation performed at every call.
248+ """
249+
243250 l = msssim_per_channel (
244251 input ,
245252 target ,
Original file line number Diff line number Diff line change @@ -46,12 +46,16 @@ class TV(nn.Module):
4646 """
4747
4848 def __init__ (self , reduction : str = 'mean' , ** kwargs ):
49+ r""""""
4950 super ().__init__ ()
5051
5152 self .reduce = build_reduce (reduction )
5253 self .kwargs = kwargs
5354
5455 def forward (self , input : torch .Tensor ) -> torch .Tensor :
56+ r"""Defines the computation performed at every call.
57+ """
58+
5559 l = tv (input , ** self .kwargs )
5660
5761 return self .reduce (l )
Original file line number Diff line number Diff line change @@ -180,12 +180,16 @@ class Intermediary(nn.Module):
180180 """
181181
182182 def __init__ (self , layers : nn .Sequential , targets : List [int ]):
183+ r""""""
183184 super ().__init__ ()
184185
185186 self .layers = layers
186187 self .targets = set (targets )
187188
188189 def forward (self , input : torch .Tensor ) -> List [torch .Tensor ]:
190+ r"""Defines the computation performed at every call.
191+ """
192+
189193 output = []
190194
191195 for i , layer in enumerate (self .layers ):
You can’t perform that action at this time.
0 commit comments