Skip to content

Commit 8c59cf0

Browse files
authored
Merge pull request #17 from alexhagen/master
adaptively moves offset to the right device so that gpu can be used
2 parents a662103 + 970c11e commit 8c59cf0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def fft_conv(
5656
kernel: Tensor,
5757
bias: Tensor = None,
5858
padding: Union[int, Iterable[int]] = 0,
59+
padding_mode: str = 'constant',
5960
stride: Union[int, Iterable[int]] = 1,
6061
dilation: Union[int, Iterable[int]] = 1,
6162
groups: int = 1,
@@ -83,7 +84,7 @@ def fft_conv(
8384
dilation_ = to_ntuple(dilation, n=n)
8485

8586
# internal dilation offsets
86-
offset = torch.zeros(1, 1, *dilation_)
87+
offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
8788
offset[(slice(None), slice(None), *((0,) * n))] = 1.0
8889

8990
# correct the kernel by cutting off unwanted dilation trailing zeros
@@ -94,7 +95,7 @@ def fft_conv(
9495

9596
# Pad the input signal & kernel tensors
9697
signal_padding = [p for p in padding_[::-1] for _ in range(2)]
97-
signal = f.pad(signal, signal_padding)
98+
signal = f.pad(signal, signal_padding, mode=padding_mode)
9899

99100
# Because PyTorch computes a *one-sided* FFT, we need the final dimension to
100101
# have *even* length. Just pad with one more zero if the final dimension is odd.
@@ -143,6 +144,7 @@ def __init__(
143144
out_channels: int,
144145
kernel_size: Union[int, Iterable[int]],
145146
padding: Union[int, Iterable[int]] = 0,
147+
padding_mode: str = 'constant',
146148
stride: Union[int, Iterable[int]] = 1,
147149
dilation: Union[int, Iterable[int]] = 1,
148150
groups: int = 1,
@@ -164,6 +166,7 @@ def __init__(
164166
self.out_channels = out_channels
165167
self.kernel_size = kernel_size
166168
self.padding = padding
169+
self.padding_mode = padding_mode
167170
self.stride = stride
168171
self.dilation = dilation
169172
self.groups = groups
@@ -192,6 +195,7 @@ def forward(self, signal):
192195
self.weight,
193196
bias=self.bias,
194197
padding=self.padding,
198+
padding_mode=self.padding_mode,
195199
stride=self.stride,
196200
dilation=self.dilation,
197201
groups=self.groups,

0 commit comments

Comments
 (0)