Skip to content

Commit 45eeae6

Browse files
committed
Fix some test bugs
1 parent 22e21c7 commit 45eeae6

File tree

3 files changed

+20
-7
lines changed

3 files changed

+20
-7
lines changed

tests/test_functional.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,20 @@ def test_fft_conv_functional(
4848
kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
4949
w0 = torch.randn(out_channels, in_channels // groups, *kernel_size,
5050
requires_grad=True)
51-
w1 = w0.detach().clone().requires_grad()
51+
w1 = w0.detach().clone().requires_grad_()
52+
53+
b0 = torch.randn(out_channels, requires_grad=True) if bias else None
54+
b1 = b0.detach().clone().requires_grad_() if bias else None
5255

53-
y0 = fft_conv(signal, w0, **kwargs)
54-
y1 = torch_conv(signal, w1, **kwargs)
56+
kwargs = dict(
57+
padding=padding,
58+
stride=stride,
59+
dilation=dilation,
60+
groups=groups,
61+
)
62+
63+
y0 = fft_conv(signal, w0, bias=b0, **kwargs)
64+
y1 = torch_conv(signal, w1, bias=b1, **kwargs)
5565

5666
_assert_almost_equal(y0, y1)
5767

tests/test_module.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,27 @@ def test_fft_conv_module(
3939
kernel_size=kernel_size,
4040
padding=padding,
4141
stride=stride,
42-
dilation=1,
42+
dilation=dilation,
4343
groups=groups,
4444
bias=bias,
4545
ndim=ndim,
4646
)
4747
batch_size = 2 # TODO: Make this non-constant?
4848
dims = ndim * [input_size]
4949
signal = torch.randn(batch_size, in_channels, *dims)
50+
51+
weight = fft_conv_layer.weight
52+
bias = fft_conv_layer.bias
53+
5054
kwargs = dict(
51-
bias=fft_conv_layer.bias,
5255
padding=padding,
5356
stride=stride,
5457
dilation=dilation,
5558
groups=groups,
5659
)
5760

5861
y0 = fft_conv_layer(signal)
59-
y1 = torch_conv(signal, fft_conv_layer._weight, **kwargs)
62+
y1 = torch_conv(signal, weight, bias=bias, **kwargs)
6063

6164
_assert_almost_equal(y0, y1)
6265

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
def _assert_almost_equal(x: Tensor, y: Tensor) -> bool:
66
abs_error = torch.abs(x - y)
7-
assert abs_error.mean().item() < 1e-5
7+
assert abs_error.mean().item() < 5e-5
88
assert abs_error.max().item() < 1e-4
99
return True
1010

0 commit comments

Comments
 (0)