|
24 | 24 | prewitt_kernel, |
25 | 25 | gradient_kernel, |
26 | 26 | channel_conv, |
| 27 | + l2_norm, |
27 | 28 | ) |
28 | 29 |
|
29 | 30 |
|
@@ -75,8 +76,8 @@ def gmsd( |
75 | 76 | # Gradient magnitude |
76 | 77 | pad = kernel.size(-1) // 2 |
77 | 78 |
|
78 | | - gm_x = torch.linalg.norm(channel_conv(x, kernel, padding=pad), dim=1) |
79 | | - gm_y = torch.linalg.norm(channel_conv(y, kernel, padding=pad), dim=1) |
| 79 | + gm_x = l2_norm(channel_conv(x, kernel, padding=pad), dims=[1]) |
| 80 | + gm_y = l2_norm(channel_conv(y, kernel, padding=pad), dims=[1]) |
80 | 81 |
|
81 | 82 | gm_xy = gm_x * gm_y |
82 | 83 |
|
@@ -109,7 +110,8 @@ def ms_gmsd( |
109 | 110 | without color space conversion. |
110 | 111 |
|
111 | 112 | .. math:: |
112 | | - \text{MS-GMSD}(x, y) = \sum^{M}_{i = 1} w_i \text{GMSD}(x^i, y^i) |
| 113 | + \text{MS-GMSD}(x, y) = |
| 114 | + \sqrt{\sum^{M}_{i = 1} w_i \text{GMSD}(x^i, y^i) ** 2} |
113 | 115 |
|
114 | 116 | where :math:`x^i` and :math:`y^i` are obtained by downsampling |
115 | 117 | the initial tensors by a factor :math:`2^{i - 1}`. |
@@ -150,8 +152,8 @@ def ms_gmsd( |
150 | 152 | c=c, alpha=alpha, |
151 | 153 | )) |
152 | 154 |
|
153 | | - msgmsd = torch.stack(gmsds, dim=-1) ** 2 |
154 | | - msgmsd = torch.sqrt((msgmsd * weights).sum(dim=-1)) |
| 155 | + msgmsd = weights * torch.stack(gmsds, dim=-1) ** 2 |
| 156 | + msgmsd = msgmsd.sum(dim=-1).sqrt() |
155 | 157 |
|
156 | 158 | return msgmsd |
157 | 159 |
|
|
0 commit comments