Skip to content

Commit c18e4b9

Browse files
authored
Merge pull request #12 from aretor/master
Add dilation parameter and tests
2 parents eb4f677 + 45fa3a2 commit c18e4b9

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def __init__(
130130
kernel_size: Union[int, Iterable[int]],
131131
padding: Union[int, Iterable[int]] = 0,
132132
stride: Union[int, Iterable[int]] = 1,
133+
dilation: Union[int, Iterable[int]] = 1,
133134
groups: int = 1,
134135
bias: bool = True,
135136
ndim: int = 1,
@@ -150,6 +151,7 @@ def __init__(
150151
self.kernel_size = kernel_size
151152
self.padding = padding
152153
self.stride = stride
154+
self.dilation = dilation
153155
self.groups = groups
154156
self.use_bias = bias
155157

@@ -165,9 +167,22 @@ def __init__(
165167
)
166168

167169
kernel_size = to_ntuple(kernel_size, ndim)
168-
self.weight = nn.Parameter(
169-
torch.randn(out_channels, in_channels // groups, *kernel_size)
170+
dilation = to_ntuple(dilation, ndim)
171+
total_size = tuple(
172+
((ks - 1) * dil + 1)
173+
for ks, dil in zip(kernel_size, dilation)
170174
)
175+
weight = torch.zeros(out_channels, in_channels // groups, *total_size)
176+
fill = torch.randn(out_channels, in_channels // groups, *kernel_size)
177+
ids = tuple(
178+
torch.arange(0, tot_sz, dil)
179+
for tot_sz, dil in zip(total_size, dilation)
180+
)
181+
182+
# workaround bc PyTorch doesn't support [:, :, <tensor tuple>] indexing
183+
weight[(slice(None), slice(None),) + torch.meshgrid(*ids)] = fill
184+
185+
self.weight = nn.Parameter(weight)
171186
self.bias = nn.Parameter(torch.randn(out_channels)) if bias else None
172187

173188
def forward(self, signal):

tests/test_fft_conv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _gcd(x: int, y: int) -> int:
2727
@pytest.mark.parametrize("kernel_size", [1, 2, 3])
2828
@pytest.mark.parametrize("padding", [0, 1])
2929
@pytest.mark.parametrize("stride", [1, 2, 3])
30+
@pytest.mark.parametrize("dilation", [1, 2, 3])
3031
@pytest.mark.parametrize("bias", [True, False])
3132
@pytest.mark.parametrize("ndim", [1, 2, 3])
3233
@pytest.mark.parametrize("input_size", [7, 8])
@@ -36,6 +37,7 @@ def test_fft_conv(
3637
kernel_size: Union[int, Iterable[int]],
3738
padding: Union[int, Iterable[int]],
3839
stride: Union[int, Iterable[int]],
40+
dilation: Union[int, Iterable[int]],
3941
groups: int,
4042
bias: bool,
4143
ndim: int,
@@ -49,6 +51,7 @@ def test_fft_conv(
4951
kernel_size=kernel_size,
5052
padding=padding,
5153
stride=stride,
54+
dilation=dilation,
5255
groups=groups,
5356
bias=bias,
5457
ndim=ndim,

0 commit comments

Comments
 (0)