1- """Feature Similarity (FSIM)
1+ r """Feature Similarity (FSIM)
22
33This module implements the FSIM in PyTorch.
44
5- Credits :
6- Inspired by the [official implementation]( https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm)
5+ Original :
6+ https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm
77
88References:
9- [1] FSIM: A Feature Similarity Index for Image Quality Assessment
10- (Zhang et al., 2011)
11- https://ieeexplore.ieee.org/document/5705575
9+ .. [Zhang2011] FSIM: A Feature Similarity Index for Image Quality Assessment (Zhang et al., 2011)
1210
13- [2] Image Features From Phase Congruency
14- (Kovesi, 1999)
15- https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.4.1641
11+ .. [Kovesi1999] Image Features From Phase Congruency (Kovesi, 1999)
1612"""
1713
1814import math
2117import torch .nn as nn
2218import torch .nn .functional as F
2319
24- from piqa .utils import _jit , assert_type , reduce_tensor
25- from piqa .utils .color import ColorConv
26- from piqa .utils .functional import (
20+ from torch import Tensor
21+
22+ from .utils import _jit , assert_type , reduce_tensor
23+ from .utils import complex as cx
24+ from .utils .color import ColorConv
25+ from .utils .functional import (
2726 scharr_kernel ,
2827 gradient_kernel ,
2928 filter_grid ,
3029 log_gabor ,
3130 channel_conv ,
3231)
3332
34- import piqa .utils .complex as cx
35-
3633
3734@_jit
3835def fsim (
39- x : torch . Tensor ,
40- y : torch . Tensor ,
41- pc_x : torch . Tensor ,
42- pc_y : torch . Tensor ,
43- kernel : torch . Tensor ,
36+ x : Tensor ,
37+ y : Tensor ,
38+ pc_x : Tensor ,
39+ pc_y : Tensor ,
40+ kernel : Tensor ,
4441 value_range : float = 1. ,
4542 t1 : float = 0.85 ,
4643 t2 : float = 160. / (255. ** 2 ),
4744 t3 : float = 200. / (255. ** 2 ),
4845 t4 : float = 200. / (255. ** 2 ),
4946 lmbda : float = 0.03 ,
50- ) -> torch . Tensor :
51- r"""Returns the FSIM between \(x\) and \(y\) ,
47+ ) -> Tensor :
48+ r"""Returns the FSIM between :math:`x` and :math:`y` ,
5249 without color space conversion and downsampling.
5350
5451 Args:
55- x: An input tensor, \(( N, 3 \text{ or } 1, H, W)\) .
56- y: A target tensor, \(( N, 3 \text{ or } 1, H, W)\) .
57- pc_x: The input phase congruency, \(( N, H, W)\) .
58- pc_y: The target phase congruency, \(( N, H, W)\) .
59- kernel: A gradient kernel, \(( 2, 1, K, K)\) .
60- value_range: The value range \(L\) of the inputs (usually 1. or 255).
52+ x: An input tensor, :math:`( N, 3 \text{ or } 1, H, W)` .
53+ y: A target tensor, :math:`( N, 3 \text{ or } 1, H, W)` .
54+ pc_x: The input phase congruency, :math:`( N, H, W)` .
55+ pc_y: The target phase congruency, :math:`( N, H, W)` .
56+ kernel: A gradient kernel, :math:`( 2, 1, K, K)` .
57+ value_range: The value range :math:`L` of the inputs (usually `1.` or `255`).
6158
62- For the remaining arguments, refer to [1].
59+ Note:
60+ For the remaining arguments, refer to [Zhang2011]_.
6361
6462 Returns:
65- The FSIM vector, \(( N,)\) .
63+ The FSIM vector, :math:`( N,)` .
6664
6765 Example:
6866 >>> x = torch.rand(5, 3, 256, 256)
@@ -118,25 +116,26 @@ def fsim(
118116
119117@_jit
120118def pc_filters (
121- x : torch . Tensor ,
119+ x : Tensor ,
122120 scales : int = 4 ,
123121 orientations : int = 4 ,
124122 wavelength : float = 6. ,
125123 factor : float = 2. ,
126124 sigma_f : float = 0.5978 , # -log(0.55)
127125 sigma_theta : float = 0.6545 , # pi / (4 * 1.2)
128- ) -> torch . Tensor :
129- r"""Returns the log-Gabor filters for `phase_congruency`.
126+ ) -> Tensor :
127+ r"""Returns the log-Gabor filters for :func: `phase_congruency`.
130128
131129 Args:
132- x: An input tensor, \(( *, H, W)\) .
133- scales: The number of scales, \( S_1\) .
134- orientations: The number of orientations, \( S_2\) .
130+ x: An input tensor, :math:`( *, H, W)` .
131+ scales: The number of scales, :math:`S_1`.
132+ orientations: The number of orientations, :math:`S_2`.
135133
136- For the remaining arguments, refer to [2].
134+ Note:
135+ For the remaining arguments, refer to [Kovesi1999]_.
137136
138137 Returns:
139- The filters tensor, \(( S_1, S_2, H, W)\) .
138+ The filters tensor, :math:`( S_1, S_2, H, W)` .
140139 """
141140
142141 r , theta = filter_grid (x )
@@ -177,24 +176,25 @@ def pc_filters(
177176
178177@_jit
179178def phase_congruency (
180- x : torch . Tensor ,
181- filters : torch . Tensor ,
179+ x : Tensor ,
180+ filters : Tensor ,
182181 value_range : float = 1. ,
183182 k : float = 2. ,
184183 rescale : float = 1.7 ,
185184 eps : float = 1e-8 ,
186- ) -> torch . Tensor :
187- r"""Returns the Phase Congruency (PC) of \(x\) .
185+ ) -> Tensor :
186+ r"""Returns the Phase Congruency (PC) of :math:`x` .
188187
189188 Args:
190- x: An input tensor, \(( N, 1, H, W)\) .
191- filters: The frequency domain filters, \(( S_1, S_2, H, W)\) .
192- value_range: The value range \(L\) of the input (usually 1. or 255).
189+ x: An input tensor, :math:`( N, 1, H, W)` .
190+ filters: The frequency domain filters, :math:`( S_1, S_2, H, W)` .
191+ value_range: The value range :math:`L` of the input (usually `1.` or `255`).
193192
194- For the remaining arguments, refer to [2].
193+ Note:
194+ For the remaining arguments, refer to [Kovesi1999]_.
195195
196196 Returns:
197- The PC tensor, \(( N, H, W)\) .
197+ The PC tensor, :math:`( N, H, W)` .
198198
199199 Example:
200200 >>> x = torch.rand(5, 1, 256, 256)
@@ -254,18 +254,24 @@ class FSIM(nn.Module):
254254 r"""Creates a criterion that measures the FSIM
255255 between an input and a target.
256256
257- Before applying `fsim`, the input and target are converted from
258- RBG to Y(IQ) and downsampled by a factor \( \ frac{\min(H, W)}{256} \) .
257+ Before applying :func:`fsim`, the input and target are converted from
258+ RGB to Y(IQ) and downsampled by a factor :math:`\frac{\min(H, W)}{256}`.
259259
260260 Args:
261261 chromatic: Whether to use the chromatic channels (IQ) or not.
262262 downsample: Whether downsampling is enabled or not.
263- kernel: A gradient kernel, \(( 2, 1, K, K)\) .
263+ kernel: A gradient kernel, :math:`( 2, 1, K, K)` .
264264 If `None`, use the Scharr kernel instead.
265265 reduction: Specifies the reduction to apply to the output:
266266 `'none'` | `'mean'` | `'sum'`.
267267
268- `**kwargs` are transmitted to `fsim`.
268+ Note:
269+ `**kwargs` are passed to :func:`fsim`.
270+
271+ Shapes:
272+ input: :math:`(N, 3, H, W)`
273+ target: :math:`(N, 3, H, W)`
274+ output: :math:`(N,)` or :math:`()` depending on `reduction`
269275
270276 Example:
271277 >>> criterion = FSIM().cuda()
@@ -281,11 +287,10 @@ def __init__(
281287 self ,
282288 chromatic : bool = True ,
283289 downsample : bool = True ,
284- kernel : torch . Tensor = None ,
290+ kernel : Tensor = None ,
285291 reduction : str = 'mean' ,
286292 ** kwargs ,
287293 ):
288- r""""""
289294 super ().__init__ ()
290295
291296 if kernel is None :
@@ -300,16 +305,9 @@ def __init__(
300305 self .value_range = kwargs .get ('value_range' , 1. )
301306 self .kwargs = kwargs
302307
303- def forward (
304- self ,
305- input : torch .Tensor ,
306- target : torch .Tensor ,
307- ) -> torch .Tensor :
308- r"""Defines the computation performed at every call.
309- """
310-
308+ def forward (self , input : Tensor , target : Tensor ) -> Tensor :
311309 assert_type (
312- [ input , target ] ,
310+ input , target ,
313311 device = self .kernel .device ,
314312 dim_range = (4 , 4 ),
315313 n_channels = 3 ,
0 commit comments