Description
I noticed while working on #277 that the gradient checks in the PyTorch implementations of the precompute transforms are very slow and in fact constitute a significant proportion of the overall test suite run time. With these checks removed I can run the whole test suite locally, distributed across 4 processes with pytest-xdist, in 7 minutes, compared to 55 minutes with the checks included.
Given how long these checks take, it might make sense to factor them out into separate tests and apply a pytest mark to them, so that they can be skipped when running the tests on pull requests and only run on merges to main and in the scheduled runs; a sketch of this approach follows below.
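A minimal sketch of what that could look like, assuming a custom marker named `slow` (the marker name and the `conftest.py` hook below are illustrative, not something the repository currently defines):

```python
# conftest.py -- register the custom marker so pytest does not emit an
# "unknown mark" warning when tests are decorated with @pytest.mark.slow.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "slow: long-running tests such as torch.autograd.gradcheck calls",
    )
```

The pull request CI jobs would then deselect the marked tests with `pytest -m "not slow"`, while the jobs triggered on merges to main and the scheduled runs would invoke `pytest` with no marker filter to retain full coverage.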
s2fft/tests/test_spherical_precompute.py, lines 74 to 88 in 1d5fa15:

```python
# Test Gradients
flm_grad_test = torch.from_numpy(flm)
flm_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    inverse,
    (
        flm_grad_test,
        L,
        spin,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
    ),
)
```
s2fft/tests/test_spherical_precompute.py, lines 135 to 149 in 1d5fa15:

```python
# Test Gradients
flm_grad_test = torch.from_numpy(flm)
flm_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    inverse,
    (
        flm_grad_test,
        L,
        0,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
        nside,
    ),
)
```
s2fft/tests/test_spherical_precompute.py, lines 203 to 217 in 1d5fa15:

```python
# Test Gradients
f_grad_test = torch.from_numpy(f)
f_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    forward,
    (
        f_grad_test,
        L,
        spin,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
    ),
)
```
s2fft/tests/test_spherical_precompute.py, lines 267 to 283 in 1d5fa15:

```python
# Test Gradients
f_grad_test = torch.from_numpy(f)
f_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    forward,
    (
        f_grad_test,
        L,
        0,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
        nside,
        iter,
    ),
)
```
s2fft/tests/test_wigner_precompute.py, lines 61 to 75 in 1d5fa15:

```python
# Test Gradients
flmn_grad_test = torch.from_numpy(flmn)
flmn_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    inverse,
    (
        flmn_grad_test,
        L,
        N,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
    ),
)
```
s2fft/tests/test_wigner_precompute.py, lines 122 to 136 in 1d5fa15:

```python
# Test Gradients
f_grad_test = torch.from_numpy(f)
f_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    forward,
    (
        f_grad_test,
        L,
        N,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
    ),
)
```
s2fft/tests/test_wigner_precompute.py, lines 178 to 193 in 1d5fa15:

```python
# Test Gradients
flmn_grad_test = torch.from_numpy(flmn)
flmn_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    inverse,
    (
        flmn_grad_test,
        L,
        N,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
        nside,
    ),
)
```
s2fft/tests/test_wigner_precompute.py, lines 237 to 252 in 1d5fa15:

```python
# Test Gradients
f_grad_test = torch.from_numpy(f)
f_grad_test.requires_grad = True
assert torch.autograd.gradcheck(
    forward,
    (
        f_grad_test,
        L,
        N,
        torch.from_numpy(kernel),
        sampling,
        reality,
        method,
        nside,
    ),
)
```
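As a hedged sketch of the refactoring itself, one of the checks above might be pulled out into its own marked test along these lines (the test name, the argument plumbing, and the import path are placeholders; the real tests parametrise these values through their existing pytest parametrisation):

```python
# Hypothetical standalone gradient test, split out of the round-trip
# accuracy test so it can be deselected on pull requests with -m "not slow".
import pytest
import torch

from s2fft.precompute_transforms.spherical import inverse  # assumed import path


@pytest.mark.slow
def test_inverse_gradients(flm, kernel, L, spin, sampling, reality, method):
    flm_grad_test = torch.from_numpy(flm)
    flm_grad_test.requires_grad = True
    # gradcheck compares analytical gradients against finite-difference
    # estimates, re-running the transform many times with perturbed inputs,
    # which is what makes these checks so expensive.
    assert torch.autograd.gradcheck(
        inverse,
        (
            flm_grad_test,
            L,
            spin,
            torch.from_numpy(kernel),
            sampling,
            reality,
            method,
        ),
    )
```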