Skip to content

Commit 2126924

Browse files
authored
Fix incomplete tests in batch and transforms (#506)
* Fix test_compute_deltas_twochannels * Fix 3batch test helper
1 parent d1adb7f commit 2126924

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

test/test_batch_consistency.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def _test_batch(functional, tensor, *args, **kwargs):
6464
expected = expected.repeat(*ind)
6565

6666
torch.random.manual_seed(42)
67-
_ = functional(tensors.clone(), *args, **kwargs)
67+
computed = functional(tensors.clone(), *args, **kwargs)
68+
69+
assert expected.shape == computed.shape, (expected.shape, computed.shape)
70+
assert torch.allclose(expected, computed, **kwargs_compare)
6871

6972

7073
class TestFunctional(unittest.TestCase):

test/test_transforms.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,12 @@ def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8)
212212

213213
def test_compute_deltas_twochannel(self):
214214
specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
215-
_ = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
216-
[0.5, 1.0, 1.0, 0.5]]])
217-
transform = transforms.ComputeDeltas()
215+
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
216+
[0.5, 1.0, 1.0, 0.5]]])
217+
transform = transforms.ComputeDeltas(win_length=3)
218218
computed = transform(specgram)
219-
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
219+
assert computed.shape == expected.shape, (computed.shape, expected.shape)
220+
assert torch.allclose(computed, expected, atol=1e-6, rtol=1e-8)
220221

221222

222223
if __name__ == '__main__':

0 commit comments

Comments
 (0)