Provide PyTorch implementations by wrapping JAX functions #277
Conversation
```diff
 flm = flm.at[:, m_start_ind:].set(
-    jnp.einsum("...tlm, ...tm -> ...lm", kernel, ftm, optimize=True)
+    jnp.einsum(
+        "...tlm, ...tm -> ...lm", kernel.astype(flm.dtype), ftm, optimize=True
+    )
 )
```
The explicit dtype cast here was needed to avoid warnings when computing gradients through this function, which were causing test failures. The warnings indicated an implicit cast from complex to real types was happening, which might be losing information. I think this was because the kernel argument here is real-valued, but as ftm is complex and the flm output of the einsum is complex, kernel is implicitly cast to complex as part of the einsum; in the reverse pass the derivatives with respect to kernel are therefore complex, even though we should only retain the real-valued component.
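As a small self-contained illustration of the pattern (made-up shapes, not the actual s2fft code): casting the real kernel to the complex dtype of the other operand makes the promotion explicit, and the gradient with respect to the kernel comes back with the kernel's real dtype.

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes standing in for the transform's kernel and coefficients.
ntheta, L, M = 8, 4, 7
kernel = jnp.ones((ntheta, L, M))                 # real-valued kernel
ftm = jnp.ones((ntheta, M), dtype=jnp.complex64)  # complex coefficients


def contract(kernel, ftm):
    # Explicit cast mirroring kernel.astype(flm.dtype) in the diff above:
    # flm shares ftm's complex dtype, so casting to ftm.dtype has the same effect.
    return jnp.einsum(
        "...tlm, ...tm -> ...lm", kernel.astype(ftm.dtype), ftm, optimize=True
    )


# Differentiate a real-valued loss with respect to the real kernel; the returned
# gradient has the kernel's real dtype (only the real component is retained).
grad_kernel = jax.grad(lambda k: jnp.sum(jnp.abs(contract(k, ftm)) ** 2))(kernel)
print(grad_kernel.dtype)  # float32
```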
As discussed in the meeting today, marking this as ready for review and requesting review from @kmulderdas. The remaining todos on exposing torch versions of the on-the-fly transforms, and at that point documenting the wider torch support, can be dealt with in a separate PR.
```diff
 methods_to_test = ["numpy", "jax", "torch"]
 recursions_to_test = ["price-mcewen", "risbo", "auto"]
-iter_to_test = [0, 3]
+iter_to_test = [0, 1]
```
I reduced the maximum number of iterations we test for, as I noticed the tests with iter=3 were particularly slow, and from a testing perspective just checking the code works with a non-zero number of iterations is sufficient.
```diff
 from s2fft.precompute_transforms.wigner import forward, inverse
 from s2fft.sampling import so3_samples as samples

+jax.config.update("jax_enable_x64", True)
```
This was needed to ensure we don't get numerical issues due to using single precision when checking gradients in the method="torch" tests, as these now use JAX under the hood.
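As a point of reference, a tiny illustrative snippet (not the test code itself) of the default behaviour this flag changes: without x64 mode, JAX silently canonicalizes double-precision arrays down to float32, which is the single-precision behaviour referred to above.

```python
# Illustrative only: demonstrates JAX's default dtype canonicalization, which the
# jax_enable_x64 flag turns off so that computations run in double precision.
import numpy as np
import jax
import jax.numpy as jnp

x = np.zeros(3, dtype=np.float64)
print(jnp.asarray(x).dtype)  # float32 under JAX's default (x64 disabled)

jax.config.update("jax_enable_x64", True)
print(jnp.asarray(x).dtype)  # float64 once x64 mode is enabled
```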
I've now added PyTorch wrappers for the on-the-fly versions of the spherical / Wigner-d transforms and updated docs and tests accordingly.
kmulderdas left a comment
I've gone over the changes and they look great. I don't see any reason for not merging this PR, although it might be good to have a chat at a later time about the breaking JAX changes (v0.6.0).
This PR removes the current manual reimplementation of the precompute transforms and some of the utility functions provided by `s2fft` to allow use with PyTorch, in favour of wrapping the JAX implementations using JAX and PyTorch's mutual support for the DLPack standard, as outlined by Matt Johnson in this Gist (a minimal sketch of the pattern is included after the benchmark files below).

Some local benchmarking suggests there is no performance degradation with this wrapping approach compared to the 'native' implementations beyond the very smallest bandlimits `L`, and a potential small constant factor speedup for larger `L` - see the benchmark results in the files below:

- precompute-spherical-torch-benchmarks.json
- precompute-spherical-torch-wrapper-benchmarks.json
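For readers unfamiliar with the approach, here is a minimal sketch of the wrapping pattern under stated assumptions (hypothetical helper names, a toy function in place of a transform, and not the actual `s2fft.utils.torch_wrapper` code): a JAX function is exposed to PyTorch as a `torch.autograd.Function`, with arrays exchanged via DLPack and the backward pass delegated to `jax.vjp`.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch
import torch.utils.dlpack


def torch_to_jax(tensor):
    # Zero-copy handoff from a torch tensor to a JAX array via a DLPack capsule.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor.contiguous()))


def jax_to_torch(array):
    # Zero-copy handoff from a JAX array back to a torch tensor.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(array))


def wrap_as_torch_function(jax_fn):
    """Wrap a single-argument JAX function as a differentiable torch function."""

    class Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Evaluate the JAX function and keep its VJP for the backward pass.
            y, vjp_fn = jax.vjp(jax_fn, torch_to_jax(x))
            ctx.vjp_fn = vjp_fn
            return jax_to_torch(y)

        @staticmethod
        def backward(ctx, grad_y):
            # Pull the incoming cotangent through the stored JAX VJP.
            (grad_x,) = ctx.vjp_fn(torch_to_jax(grad_y))
            return jax_to_torch(grad_x)

    return Wrapped.apply


# Toy usage: gradients flow from torch, through the JAX function, and back.
f_torch = wrap_as_torch_function(lambda x: jnp.sin(x) ** 2)
x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
f_torch(x).sum().backward()
print(x.grad)
```

Because DLPack shares the underlying buffer, these conversions avoid host copies when both frameworks target the same device.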
As all imports from `torch` are, after the changes in this PR, confined to the `s2fft.utils.torch_wrapper` module, and the import there is guarded in a `try: ... except ImportError` block, this PR also removes `torch` from the required dependencies for the project, with an informative error message being raised when the user tries to use the wrapper functionality without `torch` being installed.
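A minimal sketch of that import-guard pattern, with hypothetical names and error wording rather than the actual s2fft code:

```python
# Sketch only: the module-level import is guarded, and a helpful error is raised
# lazily, at the point where torch functionality is actually requested.
try:
    import torch
except ImportError:
    torch = None


def _check_torch_available():
    # Hypothetical helper name; the message wording is illustrative.
    if torch is None:
        raise RuntimeError(
            "This functionality requires PyTorch, which does not appear to be "
            "installed. Install it (for example with `pip install torch`) to use "
            "the torch wrappers."
        )
```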
Todo

- `torch_wrapper` module
- `s2fft.utils.quadrature_torch` and `s2fft.utils.resampling_torch`