Skip to content

Commit d32277b

Browse files
committed
fix: add convolve to utils. Needed in svls reg
1 parent ac94134 commit d32277b

File tree

1 file changed

+93
-0
lines changed

1 file changed

+93
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
__all__ = ["filter2D", "gaussian", "gaussian_kernel2d"]
5+
6+
7+
def filter2D(input_tensor: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
8+
"""Convolves a given kernel on input tensor without losing dimensional shape.
9+
10+
Parameters
11+
----------
12+
input_tensor : torch.Tensor
13+
Input image/tensor.
14+
kernel : torch.Tensor
15+
Convolution kernel/window.
16+
17+
Returns
18+
-------
19+
torch.Tensor:
20+
The convolved tensor of same shape as the input.
21+
"""
22+
(_, channel, _, _) = input_tensor.size()
23+
24+
# "SAME" padding to avoid losing height and width
25+
pad = [
26+
kernel.size(2) // 2,
27+
kernel.size(2) // 2,
28+
kernel.size(3) // 2,
29+
kernel.size(3) // 2,
30+
]
31+
pad_tensor = F.pad(input_tensor, pad, "replicate")
32+
33+
out = F.conv2d(pad_tensor, kernel, groups=channel)
34+
return out
35+
36+
37+
def gaussian(
38+
window_size: int, sigma: float, device: torch.device = None
39+
) -> torch.Tensor:
40+
"""Create a gaussian 1D tensor.
41+
42+
Parameters
43+
----------
44+
window_size : int
45+
Number of elements for the output tensor.
46+
sigma : float
47+
Std of the gaussian distribution.
48+
device : torch.device
49+
Device for the tensor.
50+
51+
Returns
52+
-------
53+
torch.Tensor:
54+
A gaussian 1D tensor. Shape: (window_size, ).
55+
"""
56+
x = torch.arange(window_size, device=device).float() - window_size // 2
57+
if window_size % 2 == 0:
58+
x = x + 0.5
59+
60+
gauss = torch.exp((-x.pow(2.0) / float(2 * sigma**2)))
61+
62+
return gauss / gauss.sum()
63+
64+
65+
def gaussian_kernel2d(
66+
window_size: int, sigma: float, n_channels: int = 1, device: torch.device = None
67+
) -> torch.Tensor:
68+
"""Create 2D window_size**2 sized kernel a gaussial kernel.
69+
70+
Parameters
71+
----------
72+
window_size : int
73+
Number of rows and columns for the output tensor.
74+
sigma : float
75+
Std of the gaussian distribution.
76+
n_channel : int
77+
Number of channels in the image that will be convolved with
78+
this kernel.
79+
device : torch.device
80+
Device for the kernel.
81+
82+
Returns:
83+
-----------
84+
torch.Tensor:
85+
A tensor of shape (1, 1, window_size, window_size)
86+
"""
87+
kernel_x = gaussian(window_size, sigma, device=device)
88+
kernel_y = gaussian(window_size, sigma, device=device)
89+
90+
kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
91+
kernel_2d = kernel_2d.expand(n_channels, 1, window_size, window_size)
92+
93+
return kernel_2d

0 commit comments

Comments
 (0)