Skip to content

Commit 0ff3bd3

Browse files
committed
initializing on the right device and with the right data type
1 parent 98d252e commit 0ff3bd3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def fft_conv(
8484
dilation_ = to_ntuple(dilation, n=n)
8585

8686
# internal dilation offsets
87-
offset = torch.zeros(1, 1, *dilation_).to(signal.device)
87+
offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
8888
offset[(slice(None), slice(None), *((0,) * n))] = 1.0
8989

9090
# correct the kernel by cutting off unwanted dilation trailing zeros

0 commit comments

Comments
 (0)