diff --git a/README.md b/README.md
index 454a692e..74ed156e 100644
--- a/README.md
+++ b/README.md
@@ -195,7 +195,7 @@ f = fft.wigner.inverse_jax(flmn, L, N, method="jax")
 For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).
 
 > [!NOTE]
-> We also provide PyTorch support for the precompute version of our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html).
+> We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html). This support wraps the JAX implementations, so JAX will need to be installed in addition to PyTorch.
 
 ## SSHT & HEALPix wrappers 💡
 
diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py
index bec1dd7c..aa14c9cd 100644
--- a/benchmarks/precompute_spherical.py
+++ b/benchmarks/precompute_spherical.py
@@ -1,5 +1,6 @@
 """Benchmarks for precompute spherical transforms."""
 
+import jax
 import numpy as np
 from benchmarking import (
     BenchmarkSetup,
@@ -10,6 +11,7 @@
 
 import s2fft
 import s2fft.precompute_transforms
+from s2fft.utils import torch_wrapper
 
 L_VALUES = [8, 16, 32, 64, 128, 256]
 SPIN_VALUES = [0]
@@ -31,11 +33,17 @@ def setup_forward(method, L, sampling, spin, reality, recursion):
         sampling=sampling,
         reality=reality,
     )
-    kernel_function = (
-        s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
-        if method == "jax"
-        else s2fft.precompute_transforms.construct.spin_spherical_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor and then on to a JAX array via their mutual DLPack support triggers a
+    # 'DLPack buffer is not aligned' warning about byte alignment, we first convert
+    # the NumPy arrays to JAX arrays before converting to torch tensors, which
+    # eliminates this warning.
+    if method.startswith("jax") or method.startswith("torch"):
+        flm = jax.numpy.asarray(flm)
+        f = jax.numpy.asarray(f)
+    if method.startswith("torch"):
+        flm, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flm, f))
+    kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
     kernel = kernel_function(
         L=L,
         spin=spin,
         sampling=sampling,
         reality=reality,
         forward=True,
         recursion=recursion,
     )
@@ -73,11 +81,16 @@ def setup_inverse(method, L, sampling, spin, reality, recursion):
         skip("Reality only valid for scalar fields (spin=0).")
     rng = np.random.default_rng()
     flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
-    kernel_function = (
-        s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
-        if method == "jax"
-        else s2fft.precompute_transforms.construct.spin_spherical_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor and then on to a JAX array via their mutual DLPack support triggers a
+    # 'DLPack buffer is not aligned' warning about byte alignment, we first convert
+    # the NumPy array to a JAX array before converting to a torch tensor, which
+    # eliminates this warning.
+    if method.startswith("jax") or method.startswith("torch"):
+        flm = jax.numpy.asarray(flm)
+    if method.startswith("torch"):
+        flm = torch_wrapper.jax_array_to_torch_tensor(flm)
+    kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
     kernel = kernel_function(
         L=L,
        spin=spin,
diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py
index 73971d43..5da5ec4b 100644
--- a/benchmarks/precompute_wigner.py
+++ b/benchmarks/precompute_wigner.py
@@ -1,5 +1,6 @@
 """Benchmarks for precompute Wigner-d transforms."""
 
+import jax
 import numpy as np
 from benchmarking import (
     BenchmarkSetup,
@@ -10,6 +11,7 @@
 import s2fft
 import s2fft.precompute_transforms
 from s2fft.base_transforms import wigner as base_wigner
+from s2fft.utils import torch_wrapper
 
 L_VALUES = [16, 32, 64, 128, 256]
 N_VALUES = [2]
@@ -31,11 +33,17 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode):
         sampling=sampling,
         reality=reality,
     )
-    kernel_function = (
-        s2fft.precompute_transforms.construct.wigner_kernel_jax
-        if "jax" in method
-        else s2fft.precompute_transforms.construct.wigner_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor and then on to a JAX array via their mutual DLPack support triggers a
+    # 'DLPack buffer is not aligned' warning about byte alignment, we first convert
+    # the NumPy arrays to JAX arrays before converting to torch tensors, which
+    # eliminates this warning.
+    if method.startswith("jax") or method.startswith("torch"):
+        flmn = jax.numpy.asarray(flmn)
+        f = jax.numpy.asarray(f)
+    if method.startswith("torch"):
+        flmn, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flmn, f))
+    kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
     kernel = kernel_function(
         L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode
     )
@@ -67,11 +75,16 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode):
 def setup_inverse(method, L, N, L_lower, sampling, reality, mode):
     rng = np.random.default_rng()
     flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
-    kernel_function = (
-        s2fft.precompute_transforms.construct.wigner_kernel_jax
-        if method == "jax"
-        else s2fft.precompute_transforms.construct.wigner_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor and then on to a JAX array via their mutual DLPack support triggers a
+    # 'DLPack buffer is not aligned' warning about byte alignment, we first convert
+    # the NumPy array to a JAX array before converting to a torch tensor, which
+    # eliminates this warning.
+    if method.startswith("jax") or method.startswith("torch"):
+        flmn = jax.numpy.asarray(flmn)
+    if method.startswith("torch"):
+        flmn = torch_wrapper.jax_array_to_torch_tensor(flmn)
+    kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
     kernel = kernel_function(
         L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode
     )
diff --git a/notebooks/torch_frontend.ipynb b/notebooks/torch_frontend.ipynb
index 09514ab8..0aebe2c8 100644
--- a/notebooks/torch_frontend.ipynb
+++ b/notebooks/torch_frontend.ipynb
@@ -28,27 +28,24 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models."
+    "This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms.
Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models. As the torch functions wrap the JAX implementations, we need to configure JAX to use 64-bit precision floating point types by default to ensure sufficient precision for the transforms; `S2FFT` will emit a warning if this has not been done."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.\n"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
+    "import jax\n",
+    "jax.config.update(\"jax_enable_x64\", True)\n",
    "import torch \n",
-    "import numpy as np \n",
-    "from s2fft.precompute_transforms.spherical import inverse, forward\n",
-    "from s2fft.precompute_transforms.construct import spin_spherical_kernel\n",
+    "import numpy as np\n",
+    "from s2fft.transforms.spherical import inverse, forward\n",
+    "from s2fft.precompute_transforms.spherical import (\n",
+    "    inverse as precompute_inverse, forward as precompute_forward\n",
+    ")\n",
+    "from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch\n",
    "from s2fft.utils import signal_generator"
   ]
  },
@@ -65,33 +62,40 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "L = 64 # Spherical harmonic bandlimit\n",
-    "rng = np.random.default_rng(1234951510) # Random seed for signal generator\n",
-    "flm = signal_generator.generate_flm(rng, L, using_torch=True) # Random set of spherical harmonic coefficients"
+    "L = 64 \n",
+    "rng = np.random.default_rng(1234951510)\n",
+    "flm = torch.from_numpy(signal_generator.generate_flm(rng, L))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
+    "Now let's calculate the signal on the sphere by applying the inverse spherical harmonic transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed.
Falling back to cpu.\n"
+     ]
+    }
+   ],
    "source": [
-    "inverse_kernel = spin_spherical_kernel(L, using_torch=True, forward=False) \n",
-    "forward_kernel = spin_spherical_kernel(L, using_torch=True, forward=True) "
+    "f = inverse(flm, L, method=\"torch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
+    "To calculate the corresponding spherical harmonic representation, execute"
   ]
  },
  {
   "cell_type": "code",
@@ -100,53 +104,107 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "f = inverse(flm, L, 0, inverse_kernel, method=\"torch\")"
+    "flm_check = forward(f, L, method=\"torch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "To calculate the corresponding spherical harmonic representation execute"
+    "Finally, let's check that the round-trip error is as expected for 64-bit machine precision floating point arithmetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Mean absolute error = 2.8915048238993476e-14\n"
+     ]
+    }
+   ],
   "source": [
-    "flm_check = forward(f, L, 0, forward_kernel, method=\"torch\")"
+    "print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "Finally, lets check the error on the roundtrip is at 64bit machine precision"
+    "For the fully precompute transform, we must also generate the precompute kernels, which we store as torch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inverse_kernel = spin_spherical_kernel_torch(L, forward=False) \n",
+    "forward_kernel = spin_spherical_kernel_torch(L, forward=True) "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We then pass the kernels as additional arguments to the transform functions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "precompute_f = precompute_inverse(flm, L, kernel=inverse_kernel, method=\"torch\")\n",
+    "precompute_flm_check = precompute_forward(f, L, kernel=forward_kernel, method=\"torch\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Again, we check that the round-trip error is as expected"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Mean absolute error = 
1.1866908936078849e-14\n" + "Mean absolute error = 2.8472981477378884e-14\n" ] } ], "source": [ - "print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")" + "print(f\"Mean absolute error = {np.nanmean(np.abs(precompute_flm_check - flm))}\")" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.4 ('s2fft')", + "display_name": "s2fft", "language": "python", "name": "python3" }, @@ -160,14 +218,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.0" + "version": "3.11.10" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3" - } - } + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 diff --git a/pyproject.toml b/pyproject.toml index 4af3e929..6e05c936 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,9 +28,8 @@ classifiers = [ description = "Differentiable and accelerated spherical transforms with JAX" dependencies = [ "numpy>=1.20", - "jax>=0.3.13", + "jax>=0.3.13,<0.6.0", "jaxlib", - "torch", ] dynamic = [ "version", @@ -74,6 +73,10 @@ tests = [ "pytest-cov", "so3", "pyssht", + "torch", +] +torch = [ + "torch", ] [tool.scikit-build] diff --git a/s2fft/precompute_transforms/construct.py b/s2fft/precompute_transforms/construct.py index a74a6575..db1ed57f 100644 --- a/s2fft/precompute_transforms/construct.py +++ b/s2fft/precompute_transforms/construct.py @@ -4,11 +4,10 @@ import jax import jax.numpy as jnp import numpy as np -import torch from s2fft import recursions from s2fft.sampling import s2_samples as samples -from s2fft.utils import quadrature, quadrature_jax +from s2fft.utils import quadrature, quadrature_jax, torch_wrapper # Maximum spin number at which Price-McEwen recursion is sufficiently accurate. # For spins > PM_MAX_STABLE_SPIN one should default to the Risbo recursion. @@ -22,7 +21,6 @@ def spin_spherical_kernel( sampling: str = "mw", nside: int = None, forward: bool = True, - using_torch: bool = False, recursion: str = "auto", ) -> np.ndarray: r""" @@ -50,8 +48,6 @@ def spin_spherical_kernel( forward (bool, optional): Whether to provide forward or inverse shift. Defaults to False. - using_torch (bool, optional): Desired frontend functionality. Defaults to False. - recursion (str, optional): Recursion to adopt. Supported recursion schemes include {"auto", "price-mcewen", "risbo"}. Defaults to "auto" which will detect the most appropriate recursion given the parameter configuration. @@ -163,7 +159,7 @@ def spin_spherical_kernel( healpix_phase_shifts(L, nside, forward)[:, m_start_ind:], ) - return torch.from_numpy(dl) if using_torch else dl + return dl def spin_spherical_kernel_jax( @@ -329,6 +325,11 @@ def spin_spherical_kernel_jax( return dl +spin_spherical_kernel_torch = torch_wrapper.wrap_as_torch_function( + spin_spherical_kernel_jax +) + + def wigner_kernel( L: int, N: int, @@ -337,7 +338,6 @@ def wigner_kernel( nside: int = None, forward: bool = False, mode: str = "auto", - using_torch: bool = False, ) -> np.ndarray: r""" Precompute the wigner-d kernel for Wigner transform. @@ -368,8 +368,6 @@ def wigner_kernel( {"auto", "direct", "fft"}. Defaults to "auto" which will detect the most appropriate recursion given the parameter configuration. - using_torch (bool, optional): Desired frontend functionality. Defaults to False. - Returns: np.ndarray: Transform kernel for Wigner transform. 
@@ -468,7 +466,7 @@ def wigner_kernel( healpix_phase_shifts(L, nside, forward), ) - return torch.from_numpy(dl) if using_torch else dl + return dl def wigner_kernel_jax( @@ -611,6 +609,9 @@ def wigner_kernel_jax( return dl +wigner_kernel_torch = torch_wrapper.wrap_as_torch_function(wigner_kernel_jax) + + def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]: """ Computes Fourier coefficients of the reduced Wigner d-functions and quadrature @@ -667,6 +668,11 @@ def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]: return deltas, w +fourier_wigner_kernel_torch = torch_wrapper.wrap_as_torch_function( + fourier_wigner_kernel_jax +) + + def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray: r""" Generates a phase shift vector for HEALPix for all :math:`\theta` rings. diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index 7a1a7726..878f7173 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -4,7 +4,6 @@ import jax.numpy as jnp import numpy as np -import torch from jax import jit from s2fft.precompute_transforms import construct @@ -14,7 +13,7 @@ iterative_refinement, resampling, resampling_jax, - resampling_torch, + torch_wrapper, ) @@ -194,7 +193,7 @@ def inverse_transform_jax( ftm = ftm.at[:, m_start_ind + m_offset :].add( jnp.einsum( "...tlm, ...lm -> ...tm", - kernel, + kernel.astype(ftm.dtype), flm[:, m_start_ind:], optimize=True, ) @@ -222,82 +221,7 @@ def inverse_transform_jax( return jnp.real(f) if reality else f -def inverse_transform_torch( - flm: torch.tensor, - kernel: torch.tensor, - L: int, - sampling: str, - reality: bool, - spin: int, - nside: int, -) -> torch.tensor: - r""" - Compute the inverse spherical harmonic transform via precompute (Torch - implementation). - - Args: - flm (torch.tensor): Spherical harmonic coefficients. - - kernel (torch.tensor): Wigner-d kernel. - - L (int): Harmonic band-limit. - - sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. - - reality (bool, optional): Whether the signal on the sphere is real. If so, - conjugate symmetry is exploited to reduce computational costs. - - spin (int): Harmonic spin. - - nside (int): HEALPix Nside resolution parameter. Only required - if sampling="healpix". - - Returns: - torch.tensor: Pixel-space coefficients with shape. 
- - """ - m_offset = 1 if sampling in ["mwss", "healpix"] else 0 - m_start_ind = L - 1 if reality else 0 - - ftm = torch.zeros(samples.ftm_shape(L, sampling, nside), dtype=torch.complex128) - if sampling.lower() == "healpix": - ftm[:, m_start_ind + m_offset :] += torch.einsum( - "...tlm, ...lm -> ...tm", kernel, flm[:, m_start_ind:] - ) - else: - ftm[:, m_start_ind + m_offset :].real += torch.einsum( - "...tlm, ...lm -> ...tm", kernel, flm[:, m_start_ind:].real - ) - ftm[:, m_start_ind + m_offset :].imag += torch.einsum( - "...tlm, ...lm -> ...tm", kernel, flm[:, m_start_ind:].imag - ) - ftm *= (-1) ** spin - if reality: - ftm[:, m_offset : m_start_ind + m_offset] = torch.flip( - torch.conj(ftm[:, m_start_ind + m_offset + 1 :]), dims=[-1] - ) - - if sampling.lower() == "healpix": - if reality: - ftm[:, m_offset : m_start_ind + m_offset] = torch.flip( - torch.conj(ftm[:, m_start_ind + m_offset + 1 :]), dims=[-1] - ) - f = hp.healpix_ifft(ftm, L, nside, "torch", reality) - - else: - if reality: - f = torch.fft.irfft( - ftm[:, m_start_ind + m_offset :], - samples.nphi_equiang(L, sampling), - axis=-1, - norm="forward", - ) - else: - f = torch.fft.ifftshift(ftm, dim=[-1]) - f = torch.fft.ifft(f, axis=-1, norm="forward") - - return f.real if reality else f +inverse_transform_torch = torch_wrapper.wrap_as_torch_function(inverse_transform_jax) def forward( @@ -468,7 +392,7 @@ def forward_transform_jax( nside: int, ) -> jnp.ndarray: r""" - Compute the forward spherical harmonic tranclearsform via precompute (vectorized + Compute the forward spherical harmonic transform via precompute (vectorized implementation). Args: @@ -518,7 +442,9 @@ def forward_transform_jax( flm = jnp.zeros(samples.flm_shape(L), dtype=jnp.complex128) flm = flm.at[:, m_start_ind:].set( - jnp.einsum("...tlm, ...tm -> ...lm", kernel, ftm, optimize=True) + jnp.einsum( + "...tlm, ...tm -> ...lm", kernel.astype(flm.dtype), ftm, optimize=True + ) ) if reality: @@ -532,82 +458,7 @@ def forward_transform_jax( return flm * (-1) ** spin -def forward_transform_torch( - f: torch.tensor, - kernel: torch.tensor, - L: int, - sampling: str, - reality: bool, - spin: int, - nside: int, -) -> torch.tensor: - r""" - Compute the forward spherical harmonic tranclearsform via precompute (vectorized - implementation). - - Args: - f (torch.tensor): Signal on the sphere. - - kernel (torch.tensor): Wigner-d kernel. - - L (int): Harmonic band-limit. - - sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. - - reality (bool, optional): Whether the signal on the sphere is real. If so, - conjugate symmetry is exploited to reduce computational costs. - - spin (int): Harmonic spin. - - nside (int): HEALPix Nside resolution parameter. Only required - if sampling="healpix". - - Returns: - torch.tensor: Pixel-space coefficients. 
- - """ - if sampling.lower() == "mw": - f = resampling_torch.mw_to_mwss(f, L, spin) - - if sampling.lower() in ["mw", "mwss"]: - sampling = "mwss" - f = resampling_torch.upsample_by_two_mwss(f, L, spin) - - m_offset = 1 if sampling in ["mwss", "healpix"] else 0 - m_start_ind = L - 1 if reality else 0 - - if sampling.lower() == "healpix": - ftm = hp.healpix_fft(f, L, nside, "torch", reality)[:, m_offset:] - if reality: - ftm = ftm[:, m_start_ind:] - else: - if reality: - ftm = torch.fft.rfft(torch.real(f), axis=-1, norm="backward") - if m_offset != 0: - ftm = ftm[:, :-1] - else: - ftm = torch.fft.fft(f, axis=-1, norm="backward") - ftm = torch.fft.fftshift(ftm, dim=[-1])[:, m_offset:] - - flm = torch.zeros(samples.flm_shape(L), dtype=torch.complex128) - if sampling.lower() == "healpix": - flm[:, m_start_ind:] = torch.einsum("...tlm, ...tm -> ...lm", kernel, ftm) - else: - flm[:, m_start_ind:].real = torch.einsum( - "...tlm, ...tm -> ...lm", kernel, ftm.real - ) - flm[:, m_start_ind:].imag = torch.einsum( - "...tlm, ...tm -> ...lm", kernel, ftm.imag - ) - - if reality: - flm[:, :m_start_ind] = torch.flip( - (-1) ** (torch.arange(1, L) % 2) * torch.conj(flm[:, m_start_ind + 1 :]), - dims=[-1], - ) - - return flm * (-1) ** spin +forward_transform_torch = torch_wrapper.wrap_as_torch_function(forward_transform_jax) _inverse_functions = { @@ -624,7 +475,7 @@ def forward_transform_torch( } _kernel_functions = { - "numpy": partial(construct.spin_spherical_kernel, using_torch=False), + "numpy": construct.spin_spherical_kernel, "jax": construct.spin_spherical_kernel_jax, - "torch": partial(construct.spin_spherical_kernel, using_torch=True), + "torch": construct.spin_spherical_kernel_torch, } diff --git a/s2fft/precompute_transforms/wigner.py b/s2fft/precompute_transforms/wigner.py index 219ce71b..4222392e 100644 --- a/s2fft/precompute_transforms/wigner.py +++ b/s2fft/precompute_transforms/wigner.py @@ -2,12 +2,12 @@ import jax.numpy as jnp import numpy as np -import torch from jax import jit +from s2fft.precompute_transforms import construct from s2fft.sampling import so3_samples as samples from s2fft.utils import healpix_ffts as hp -from s2fft.utils import resampling, resampling_jax, resampling_torch +from s2fft.utils import resampling, resampling_jax, torch_wrapper def inverse( @@ -60,14 +60,21 @@ def inverse( :math:`\alpha` with :math:`\phi`. 
""" - if method == "numpy": - return inverse_transform(flmn, kernel, L, N, sampling, reality, nside) - elif method == "jax": - return inverse_transform_jax(flmn, kernel, L, N, sampling, reality, nside) - elif method == "torch": - return inverse_transform_torch(flmn, kernel, L, N, sampling, reality, nside) - else: + if method not in _inverse_functions: raise ValueError(f"Method {method} not recognised.") + common_kwargs = { + "L": L, + "N": N, + "sampling": sampling, + "reality": reality, + "nside": nside, + } + kernel = ( + _kernel_functions[method](forward=False, **common_kwargs) + if kernel is None + else kernel + ) + return _inverse_functions[method](flmn, kernel, **common_kwargs) def inverse_transform( @@ -174,7 +181,7 @@ def inverse_transform_jax( fnab = fnab.at[n_start_ind:, :, m_offset:].set( jnp.einsum( "...ntlm, ...nlm -> ...ntm", - kernel, + kernel.astype(fnab.dtype), flmn[n_start_ind:, :, :], optimize=True, ) @@ -211,82 +218,7 @@ def inverse_transform_jax( return jnp.conj(jnp.fft.fft2(fnab, axes=(-1, -3), norm="backward")) -def inverse_transform_torch( - flmn: torch.tensor, - kernel: torch.tensor, - L: int, - N: int, - sampling: str, - reality: bool, - nside: int, -) -> torch.tensor: - r""" - Compute the inverse Wigner transform, i.e. inverse Fourier transform on - :math:`SO(3)`. - - Args: - flmn (torch.tensor): Wigner coefficients with shape :math:`[2N-1, L, 2L-1]`. - - kernel (torch.tensor): Wigner-d kernel. - - L (int): Harmonic band-limit. - - N (int): Directional band-limit. - - sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. - - reality (bool, optional): Whether the signal on the sphere is real. If so, - conjugate symmetry is exploited to reduce computational costs. - - nside (int): HEALPix Nside resolution parameter. Only required - if sampling="healpix". - - Returns: - torch.tensor: Pixel-space coefficients. - - """ - m_offset = 1 if sampling in ["mwss", "healpix"] else 0 - n_start_ind = N - 1 if reality else 0 - - fnab = torch.zeros( - samples.fnab_shape(L, N, sampling, nside), dtype=torch.complex128 - ) - if sampling.lower() == "healpix": - fnab[n_start_ind:, :, m_offset:] = torch.einsum( - "...ntlm, ...nlm -> ...ntm", kernel, flmn[n_start_ind:, :, :] - ) - else: - fnab[n_start_ind:, :, m_offset:].real = torch.einsum( - "...ntlm, ...nlm -> ...ntm", kernel, flmn[n_start_ind:, :, :].real - ) - fnab[n_start_ind:, :, m_offset:].imag = torch.einsum( - "...ntlm, ...nlm -> ...ntm", kernel, flmn[n_start_ind:, :, :].imag - ) - - if sampling.lower() in "healpix": - f = torch.zeros(samples.f_shape(L, N, sampling, nside), dtype=torch.complex128) - for n in range(n_start_ind - N + 1, N): - ind = N - 1 + n - f[ind] = hp.healpix_ifft(fnab[ind], L, nside, "torch") - if reality: - return torch.fft.irfft(f[n_start_ind:], 2 * N - 1, axis=-2, norm="forward") - else: - return torch.fft.ifft( - torch.fft.ifftshift(f, dim=[-2]), axis=-2, norm="forward" - ) - - else: - if reality: - fnab = torch.fft.ifft( - torch.fft.ifftshift(fnab, dim=[-1]), axis=-1, norm="forward" - ) - return torch.fft.irfft( - fnab[n_start_ind:], 2 * N - 1, axis=-3, norm="forward" - ) - else: - fnab = torch.fft.ifftshift(fnab, dim=[-1, -3]) - return torch.fft.ifft2(fnab, dim=[-1, -3], norm="forward") +inverse_transform_torch = torch_wrapper.wrap_as_torch_function(inverse_transform_jax) def forward( @@ -339,14 +271,21 @@ def forward( :math:`\alpha` with :math:`\phi`. 
""" - if method == "numpy": - return forward_transform(f, kernel, L, N, sampling, reality, nside) - elif method == "jax": - return forward_transform_jax(f, kernel, L, N, sampling, reality, nside) - elif method == "torch": - return forward_transform_torch(f, kernel, L, N, sampling, reality, nside) - else: + if method not in _forward_functions: raise ValueError(f"Method {method} not recognised.") + common_kwargs = { + "L": L, + "N": N, + "sampling": sampling, + "reality": reality, + "nside": nside, + } + kernel = ( + _kernel_functions[method](forward=True, **common_kwargs) + if kernel is None + else kernel + ) + return _forward_functions[method](f, kernel, **common_kwargs) def forward_transform( @@ -500,7 +439,9 @@ def forward_transform_jax( flmn = jnp.zeros(samples.flmn_shape(L, N), dtype=jnp.complex128) flmn = flmn.at[n_start_ind:].set( - jnp.einsum("...ntlm, ...ntm -> ...nlm", kernel, fban, optimize=True) + jnp.einsum( + "...ntlm, ...ntm -> ...nlm", kernel.astype(flmn.dtype), fban, optimize=True + ) ) if reality: flmn = flmn.at[:n_start_ind].set( @@ -526,96 +467,24 @@ def forward_transform_jax( return flmn -def forward_transform_torch( - f: torch.tensor, - kernel: torch.tensor, - L: int, - N: int, - sampling: str, - reality: bool, - nside: int, -) -> torch.tensor: - r""" - Compute the forward Wigner transform, i.e. Fourier transform on - :math:`SO(3)`. - - Args: - f (torch.tensor): Signal on the sphere. - - kernel (torch.tensor): Wigner-d kernel. - - L (int): Harmonic band-limit. - - N (int): Directional band-limit. - - sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. - - reality (bool, optional): Whether the signal on the sphere is real. If so, - conjugate symmetry is exploited to reduce computational costs. - - nside (int): HEALPix Nside resolution parameter. Only required - if sampling="healpix". +forward_transform_torch = torch_wrapper.wrap_as_torch_function(forward_transform_jax) - Returns: - torch.tensor: Wigner space coefficients. 
- """ - n_start_ind = N - 1 if reality else 0 +_inverse_functions = { + "numpy": inverse_transform, + "jax": inverse_transform_jax, + "torch": inverse_transform_torch, +} - ax = -2 if sampling.lower() == "healpix" else -3 - if reality: - fban = torch.fft.rfft(torch.real(f), axis=ax, norm="backward") - else: - fban = torch.fft.fftshift(torch.fft.fft(f, axis=ax, norm="backward"), dim=ax) - spins = -torch.arange(n_start_ind - N + 1, N) - if sampling.lower() == "mw": - fban = resampling_torch.mw_to_mwss(fban, L, spins) +_forward_functions = { + "numpy": forward_transform, + "jax": forward_transform_jax, + "torch": forward_transform_torch, +} - if sampling.lower() in ["mw", "mwss"]: - sampling = "mwss" - fban = resampling_torch.upsample_by_two_mwss(fban, L, spins) - - m_offset = 1 if sampling in ["mwss", "healpix"] else 0 - - if sampling.lower() in "healpix": - temp = torch.zeros( - samples.fnab_shape(L, N, sampling, nside), dtype=torch.complex128 - ) - for n in range(n_start_ind - N + 1, N): - ind = n if reality else N - 1 + n - temp[N - 1 + n] = hp.healpix_fft(fban[ind], L, nside, "torch") - fban = temp[n_start_ind:, :, m_offset:] - - else: - fban = torch.fft.fft(fban, axis=-1, norm="backward") - fban = torch.fft.fftshift(fban, dim=[-1])[:, :, m_offset:] - - flmn = torch.zeros(samples.flmn_shape(L, N), dtype=torch.complex128) - - if sampling.lower() == "healpix": - flmn[n_start_ind:] = torch.einsum("...ntlm, ...ntm -> ...nlm", kernel, fban) - else: - flmn[n_start_ind:].real = torch.einsum( - "...ntlm, ...ntm -> ...nlm", kernel, fban.real - ) - flmn[n_start_ind:].imag = torch.einsum( - "...ntlm, ...ntm -> ...nlm", kernel, fban.imag - ) - if reality: - flmn[:n_start_ind] = torch.conj( - torch.flip(flmn[n_start_ind + 1 :], dims=(-1, -3)) - ) - flmn[:n_start_ind] = torch.einsum( - "...nlm,...m->...nlm", - flmn[:n_start_ind], - (-1) ** abs(torch.arange(-L + 1, L)), - ) - flmn[:n_start_ind] = torch.einsum( - "...nlm,...n->...nlm", - flmn[:n_start_ind], - (-1) ** abs(torch.arange(-N + 1, 0)), - ) - - return flmn +_kernel_functions = { + "numpy": construct.wigner_kernel, + "jax": construct.wigner_kernel_jax, + "torch": construct.wigner_kernel_torch, +} diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 52f1db61..7d3ff051 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -15,6 +15,7 @@ quadrature_jax, resampling, resampling_jax, + torch_wrapper, ) @@ -47,8 +48,8 @@ def inverse( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax", "jax_ssht", "jax_healpy"}. - Defaults to "numpy". + method (str, optional): Execution mode in {"numpy", "jax", "jax_cuda", + "jax_ssht", "jax_healpy"}. Defaults to "numpy". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -82,28 +83,31 @@ def inverse( recover acceleration by the number of devices. 
""" - if spin >= 8 and method in ["numpy", "jax"]: + if method not in _inverse_functions: + raise ValueError(f"Method {method} not recognised.") + + if spin >= 8 and method in ("numpy", "jax", "jax_cuda", "torch"): raise Warning("Recursive transform may provide lower precision beyond spin ~ 8") - if method == "numpy": - return inverse_numpy(flm, L, spin, nside, sampling, reality, precomps, L_lower) - elif method == "jax": - return inverse_jax( - flm, L, spin, nside, sampling, reality, precomps, spmd, L_lower - ) - elif method == "jax_ssht": + inverse_kwargs = {"flm": flm, "L": L} + if method in ("numpy", "jax", "jax_cuda", "torch"): + inverse_kwargs.update(sampling=sampling, precomps=precomps, L_lower=L_lower) + if method in ("jax", "jax_cuda", "torch"): + inverse_kwargs["spmd"] = spmd + if method == "jax_healpy": + if sampling.lower() != "healpix": + raise ValueError("Healpy only supports healpix sampling.") + else: + inverse_kwargs.update(spin=spin, reality=reality) + if method == "jax_ssht": if sampling.lower() == "healpix": raise ValueError("SSHT does not support healpix sampling.") ssht_sampling = ["mw", "mwss", "dh", "gl"].index(sampling.lower()) - return c_sph.ssht_inverse(flm, L, spin, reality, ssht_sampling, _ssht_backend) - elif method == "jax_healpy": - if sampling.lower() != "healpix": - raise ValueError("Healpy only supports healpix sampling.") - return c_sph.healpy_inverse(flm, L, nside) + inverse_kwargs.update(ssht_sampling=ssht_sampling, _ssht_backend=_ssht_backend) else: - raise ValueError( - f"Implementation {method} not recognised. Should be either numpy or jax." - ) + inverse_kwargs["nside"] = nside + + return _inverse_functions[method](**inverse_kwargs) def inverse_numpy( @@ -205,7 +209,7 @@ def inverse_numpy( return np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward") -@partial(jit, static_argnums=(1, 3, 4, 5, 7, 8)) +@partial(jit, static_argnums=(1, 3, 4, 5, 7, 8, 9)) def inverse_jax( flm: jnp.ndarray, L: int, @@ -216,6 +220,7 @@ def inverse_jax( precomps: List = None, spmd: bool = False, L_lower: int = 0, + use_healpix_custom_primitive: bool = False, ) -> jnp.ndarray: r""" Compute the inverse spin-spherical harmonic transform (JAX). @@ -251,6 +256,12 @@ def inverse_jax( L_lower (int, optional): Harmonic lower-bound. Transform will only be computed for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0. + use_healpix_custom_primitive (bool, optional): Whether to use a custom CUDA + primitive for computing HEALPix fast Fourier transform when `sampling = + "healpix"` and running on a CUDA compatible GPU device. Using a custom + primitive reduces long compilation times when just-in-time compiling. + Defaults to `False`. + Returns: jnp.ndarray: Signal on the sphere. @@ -326,13 +337,19 @@ def f_bwd(res, gtm): jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1) ) if sampling.lower() == "healpix": - return hp.healpix_ifft(ftm, L, nside, "jax") + if use_healpix_custom_primitive: + return hp.healpix_ifft(ftm, L, nside, "cuda") + else: + return hp.healpix_ifft(ftm, L, nside, "jax") else: ftm = jnp.conj(jnp.fft.ifftshift(ftm, axes=1)) f = jnp.conj(jnp.fft.fft(ftm, axis=1, norm="backward")) return jnp.real(f) if reality else f +inverse_torch = torch_wrapper.wrap_as_torch_function(inverse_jax) + + def forward( f: np.ndarray, L: int, @@ -363,8 +380,8 @@ def forward( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". 
-    method (str, optional): Execution mode in {"numpy", "jax", "jax_ssht", "jax_healpy"}.
-        Defaults to "numpy".
+    method (str, optional): Execution mode in {"numpy", "jax", "jax_cuda",
+        "jax_ssht", "jax_healpy"}. Defaults to "numpy".
 
     reality (bool, optional): Whether the signal on the sphere is real. If so,
         conjugate symmetry is exploited to reduce computational costs. Defaults to
@@ -406,50 +423,46 @@ def forward(
         recover acceleration by the number of devices.
 
     """
-    if spin >= 8 and method in ["numpy", "jax"]:
+    if method not in _forward_functions:
+        raise ValueError(f"Method {method} not recognised.")
+
+    if spin >= 8 and method in ("numpy", "jax", "jax_cuda", "torch"):
         raise Warning("Recursive transform may provide lower precision beyond spin ~ 8")
 
     if iter is None:
         iter = 3 if sampling.lower() == "healpix" and method == "jax_healpy" else 0
-    if method in {"numpy", "jax", "cuda"}:
-        common_kwargs = {
-            "L": L,
-            "spin": spin,
-            "nside": nside,
-            "sampling": sampling,
-            "reality": reality,
-            "L_lower": L_lower,
-        }
-        forward_kwargs = {**common_kwargs, "precomps": precomps}
-        inverse_kwargs = common_kwargs
-        if method in {"jax", "cuda"}:
-            forward_kwargs["spmd"] = spmd
-            forward_kwargs["use_healpix_custom_primitive"] = method == "cuda"
-            inverse_kwargs["method"] = "jax"
-            inverse_kwargs["spmd"] = spmd
-            forward_function = forward_jax
-        else:
-            inverse_kwargs["method"] = "numpy"
-            forward_function = forward_numpy
-        return iterative_refinement.forward_with_iterative_refinement(
-            f=f,
-            n_iter=iter,
-            forward_function=partial(forward_function, **forward_kwargs),
-            backward_function=partial(inverse, **inverse_kwargs),
-        )
-    elif method == "jax_ssht":
+
+    forward_kwargs = {"f": f, "L": L}
+    if method in ("numpy", "jax", "jax_cuda", "torch"):
+        forward_kwargs.update(sampling=sampling, precomps=precomps, L_lower=L_lower)
+    if method in ("jax", "jax_cuda", "torch"):
+        forward_kwargs["spmd"] = spmd
+    if method == "jax_healpy":
+        if sampling.lower() != "healpix":
+            raise ValueError("Healpy only supports healpix sampling.")
+        forward_kwargs["iter"] = iter
+    else:
+        forward_kwargs.update(spin=spin, reality=reality)
+    if method == "jax_ssht":
         if sampling.lower() == "healpix":
             raise ValueError("SSHT does not support healpix sampling.")
         ssht_sampling = ["mw", "mwss", "dh", "gl"].index(sampling.lower())
-        return c_sph.ssht_forward(f, L, spin, reality, ssht_sampling, _ssht_backend)
-    elif method == "jax_healpy":
-        if sampling.lower() != "healpix":
-            raise ValueError("Healpy only supports healpix sampling.")
-        return c_sph.healpy_forward(f, L, nside, iter)
+        forward_kwargs.update(ssht_sampling=ssht_sampling, _ssht_backend=_ssht_backend)
     else:
-        raise ValueError(
-            f"Implementation {method} not recognised. Should be either numpy or jax."
+        forward_kwargs["nside"] = nside
+
+    if iter > 0 and method != "jax_healpy":
+        f = forward_kwargs.pop("f")
+        inverse_kwargs = forward_kwargs.copy()
+        inverse_kwargs.pop("precomps")
+        return iterative_refinement.forward_with_iterative_refinement(
+            f=f,
+            n_iter=iter,
+            forward_function=partial(_forward_functions[method], **forward_kwargs),
+            backward_function=partial(_inverse_functions[method], **inverse_kwargs),
        )
+    else:
+        return _forward_functions[method](**forward_kwargs)
 
 
 def forward_numpy(
@@ -624,10 +637,10 @@ def forward_jax(
         for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
 
use_healpix_custom_primitive (bool, optional): Whether to use a custom CUDA - primitive for computing HEALPix fast fourier transform when `sampling = - "healpix"` and running on a cuda compatible gpu device. using a custom - primitive reduces long compilation times when jit compiling. defaults to - `False`. + primitive for computing HEALPix fast Fourier transform when `sampling = + "healpix"` and running on a CUDA compatible GPU device. Using a custom + primitive reduces long compilation times when just-in-time compiling. + Defaults to `False`. Returns: jnp.ndarray: Spherical harmonic coefficients @@ -741,3 +754,25 @@ def f_bwd(res, glm): flm = jnp.where(indices < abs(spin), jnp.zeros_like(flm), flm[..., :]) return flm * (-1) ** jnp.abs(spin) + + +forward_torch = torch_wrapper.wrap_as_torch_function(forward_jax) + + +_inverse_functions = { + "numpy": inverse_numpy, + "jax": inverse_jax, + "jax_cuda": partial(inverse_jax, use_healpix_custom_primitive=True), + "jax_ssht": c_sph.ssht_inverse, + "jax_healpy": c_sph.healpy_inverse, + "torch": inverse_torch, +} + +_forward_functions = { + "numpy": forward_numpy, + "jax": forward_jax, + "jax_cuda": partial(forward_jax, use_healpix_custom_primitive=True), + "jax_ssht": c_sph.ssht_forward, + "jax_healpy": c_sph.healpy_forward, + "torch": forward_torch, +} diff --git a/s2fft/transforms/wigner.py b/s2fft/transforms/wigner.py index 7c195737..a9126b24 100644 --- a/s2fft/transforms/wigner.py +++ b/s2fft/transforms/wigner.py @@ -8,6 +8,7 @@ import s2fft from s2fft.sampling import so3_samples as samples from s2fft.transforms import c_backend_spherical as c_sph +from s2fft.utils import torch_wrapper def inverse( @@ -81,21 +82,30 @@ def inverse( IEEE Transactions on Signal Processing 59 (2011): 5876-5887. """ - if N >= 8 and method in ["numpy", "jax"]: + if method not in _inverse_functions: + raise ValueError(f"Method {method} not recognised.") + + if N >= 8 and method in ("numpy", "jax", "torch"): raise Warning("Recursive transform may provide lower precision beyond N ~ 8") - if method == "numpy": - return inverse_numpy(flmn, L, N, nside, sampling, reality, precomps, L_lower) - elif method == "jax": - return inverse_jax(flmn, L, N, nside, sampling, reality, precomps, L_lower) - elif method == "jax_ssht": + inverse_kwargs = { + "flmn": flmn, + "L": L, + "N": N, + "L_lower": L_lower, + "sampling": sampling, + "reality": reality, + } + + if method in ("jax", "numpy", "torch"): + inverse_kwargs.update(nside=nside, precomps=precomps) + + if method == "jax_ssht": if sampling.lower() == "healpix": raise ValueError("SSHT does not support healpix sampling.") - return inverse_jax_ssht(flmn, L, N, L_lower, sampling, reality, _ssht_backend) - else: - raise ValueError( - f"Implementation {method} not recognised. Should be either numpy or jax." - ) + inverse_kwargs["_ssht_backend"] = _ssht_backend + + return _inverse_functions[method](**inverse_kwargs) def inverse_numpy( @@ -279,6 +289,9 @@ def func(flm, spin, p0, p1, p2, p3, p4): return f +inverse_torch = torch_wrapper.wrap_as_torch_function(inverse_jax) + + def inverse_jax_ssht( flmn: jnp.ndarray, L: int, @@ -401,21 +414,30 @@ def forward( IEEE Transactions on Signal Processing 59 (2011): 5876-5887. 
""" - if N >= 8 and method in ["numpy", "jax"]: + if method not in _inverse_functions: + raise ValueError(f"Method {method} not recognised.") + + if N >= 8 and method in ("numpy", "jax", "torch"): raise Warning("Recursive transform may provide lower precision beyond N ~ 8") - if method == "numpy": - return forward_numpy(f, L, N, nside, sampling, reality, precomps, L_lower) - elif method == "jax": - return forward_jax(f, L, N, nside, sampling, reality, precomps, L_lower) - elif method == "jax_ssht": + forward_kwargs = { + "f": f, + "L": L, + "N": N, + "L_lower": L_lower, + "sampling": sampling, + "reality": reality, + } + + if method in ("jax", "numpy", "torch"): + forward_kwargs.update(nside=nside, precomps=precomps) + + if method == "jax_ssht": if sampling.lower() == "healpix": raise ValueError("SSHT does not support healpix sampling.") - return forward_jax_ssht(f, L, N, L_lower, sampling, reality, _ssht_backend) - else: - raise ValueError( - f"Implementation {method} not recognised. Should be either numpy or jax." - ) + forward_kwargs["_ssht_backend"] = _ssht_backend + + return _forward_functions[method](**forward_kwargs) def forward_numpy( @@ -624,6 +646,9 @@ def func(fba, spin, p0, p1, p2, p3, p4): return flmn +forward_torch = torch_wrapper.wrap_as_torch_function(forward_jax) + + def forward_jax_ssht( f: jnp.ndarray, L: int, @@ -805,3 +830,18 @@ def _fban_to_f(fban: jnp.ndarray, L: int, N: int, reality: bool = False) -> jnp. else: f = jnp.fft.ifft(jnp.fft.ifftshift(fban, axes=-3), axis=-3, norm="forward") return f + + +_inverse_functions = { + "numpy": inverse_numpy, + "jax": inverse_jax, + "jax_ssht": inverse_jax_ssht, + "torch": inverse_torch, +} + +_forward_functions = { + "numpy": forward_numpy, + "jax": forward_jax, + "jax_ssht": forward_jax_ssht, + "torch": forward_torch, +} diff --git a/s2fft/utils/healpix_ffts.py b/s2fft/utils/healpix_ffts.py index 075a35ce..b5a5efc8 100644 --- a/s2fft/utils/healpix_ffts.py +++ b/s2fft/utils/healpix_ffts.py @@ -3,7 +3,6 @@ import jax.numpy as jnp import jaxlib.mlir.ir as ir import numpy as np -import torch from jax import jit, vmap # did not find promote_dtypes_complex outside _src @@ -14,6 +13,7 @@ from s2fft.sampling import s2_samples as samples from s2fft.utils.jax_primitive import register_primitive +from s2fft.utils.torch_wrapper import wrap_as_torch_function def spectral_folding(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: @@ -79,39 +79,7 @@ def spectral_folding_jax(fm: jnp.ndarray, nphi: int, L: int) -> jnp.ndarray: ) -def spectral_folding_torch(fm: torch.tensor, nphi: int, L: int) -> torch.tensor: - """ - Folds higher frequency Fourier coefficients back onto lower frequency - coefficients, i.e. aliasing high frequencies. Torch specific implementation of - :func:`~spectral_folding`. - - Args: - fm (torch.tensor): Slice of Fourier coefficients corresponding to ring at latitute t. - - nphi (int): Total number of pixel space phi samples for latitude t. - - L (int): Harmonic band-limit. - - Returns: - torch.tensor: Lower resolution set of aliased Fourier coefficients. 
- - """ - slice_start = L - nphi // 2 - slice_stop = slice_start + nphi - ftm_slice = fm[slice_start:slice_stop] - - ftm_slice = ftm_slice.put_( - -torch.arange(1, L - nphi // 2 + 1) % nphi, - fm[slice_start - torch.arange(1, L - nphi // 2 + 1)], - accumulate=True, - ) - ftm_slice = ftm_slice.put_( - torch.arange(L - nphi // 2) % nphi, - fm[slice_stop + torch.arange(L - nphi // 2)], - accumulate=True, - ) - - return ftm_slice +spectral_folding_torch = wrap_as_torch_function(spectral_folding_jax) def spectral_periodic_extension(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: @@ -174,29 +142,9 @@ def spectral_periodic_extension_jax(fm: jnp.ndarray, L: int) -> jnp.ndarray: ) -def spectral_periodic_extension_torch(fm: torch.tensor, L: int) -> torch.tensor: - """ - Extends lower frequency Fourier coefficients onto higher frequency - coefficients, i.e. imposed periodicity in Fourier space. Based on - :func:`~spectral_periodic_extension`. - - Args: - fm (torch.tensor): Slice of Fourier coefficients corresponding to ring at latitute t. - - L (int): Harmonic band-limit. - - Returns: - torch.tensor: Higher resolution set of periodic Fourier coefficients. - - """ - nphi = fm.shape[0] - return torch.concatenate( - ( - fm[-torch.arange(L - nphi // 2, 0, -1) % nphi], - fm, - fm[torch.arange(L - (nphi + 1) // 2) % nphi], - ) - ) +spectral_periodic_extension_torch = wrap_as_torch_function( + spectral_periodic_extension_jax +) def healpix_fft( @@ -231,16 +179,9 @@ def healpix_fft( np.ndarray: Array of Fourier coefficients for all latitudes. """ - if method.lower() == "numpy": - return healpix_fft_numpy(f, L, nside, reality) - elif method.lower() == "jax": - return healpix_fft_jax(f, L, nside, reality) - elif method.lower() == "cuda": - return healpix_fft_cuda(f, L, nside, reality) - elif method.lower() == "torch": - return healpix_fft_torch(f, L, nside, reality) - else: + if method not in _healpix_fft_functions: raise ValueError(f"Method {method} not recognised.") + return _healpix_fft_functions[method](f, L, nside, reality) def healpix_fft_numpy(f: np.ndarray, L: int, nside: int, reality: bool) -> np.ndarray: @@ -350,49 +291,7 @@ def f_chunks_to_ftm_rows(f_chunks, nphi): ) -def healpix_fft_torch( - f: torch.tensor, L: int, nside: int, reality: bool -) -> torch.tensor: - """ - Computes the Forward Fast Fourier Transform with spectral back-projection - in the polar regions to manually enforce Fourier periodicity. Torch specific - implementation of :func:`~healpix_fft_numpy`. - - Args: - f (torch.tensor): HEALPix pixel-space array. - - L (int): Harmonic band-limit. - - nside (int): HEALPix Nside resolution parameter. - - reality (bool): Whether the signal on the sphere is real. If so, - conjugate symmetry is exploited to reduce computational costs. - - Returns: - torch.tensor: Array of Fourier coefficients for all latitudes. 
- - """ - index = 0 - ftm = torch.zeros(samples.ftm_shape(L, "healpix", nside), dtype=torch.complex128) - ntheta = ftm.shape[0] - for t in range(ntheta): - nphi = samples.nphi_ring(t, nside) - if reality and nphi == 2 * L: - fm_chunk = torch.zeros(nphi, dtype=torch.complex128) - fm_chunk[nphi // 2 :] = torch.fft.rfft( - torch.real(f[index : index + nphi]), norm="backward" - )[:-1] - else: - fm_chunk = torch.fft.fftshift( - torch.fft.fft(f[index : index + nphi], norm="backward") - ) - ftm[t] = ( - fm_chunk - if nphi == 2 * L - else spectral_periodic_extension_torch(fm_chunk, L) - ) - index += nphi - return ftm +healpix_fft_torch = wrap_as_torch_function(healpix_fft_jax) def healpix_ifft( @@ -413,7 +312,7 @@ def healpix_ifft( nside (int): HEALPix Nside resolution parameter. - method (str, optional): Evaluation method in {"numpy", "jax", "torch"}. + method (str, optional): Evaluation method in {"numpy", "jax", "torch", "cuda"}. Defaults to "numpy". reality (bool): Whether the signal on the sphere is real. If so, @@ -421,23 +320,16 @@ def healpix_ifft( Defaults to False. Raises: - ValueError: Deployment method not in {"numpy", "jax", "torch"}. + ValueError: Deployment method not in {"numpy", "jax", "torch", "cuda"}. Returns: np.ndarray: HEALPix pixel-space array. """ assert L >= 2 * nside - if method.lower() == "numpy": - return healpix_ifft_numpy(ftm, L, nside, reality) - elif method.lower() == "jax": - return healpix_ifft_jax(ftm, L, nside, reality) - elif method.lower() == "cuda": - return healpix_ifft_cuda(ftm, L, nside, reality) - elif method.lower() == "torch": - return healpix_ifft_torch(ftm, L, nside, reality) - else: + if method not in _healpix_ifft_functions: raise ValueError(f"Method {method} not recognised.") + return _healpix_ifft_functions[method](ftm, L, nside, reality) def healpix_ifft_numpy( @@ -531,46 +423,7 @@ def ftm_rows_to_f_chunks(ftm_rows, nphi): ) -def healpix_ifft_torch( - ftm: torch.tensor, L: int, nside: int, reality: bool -) -> torch.tensor: - """ - Computes the Inverse Fast Fourier Transform with spectral folding in the polar - regions to mitigate aliasing. Torch specific implementation of - :func:`~healpix_ifft_numpy`. - - Args: - ftm (torch.tensor): Array of Fourier coefficients for all latitudes. - - L (int): Harmonic band-limit. - - nside (int): HEALPix Nside resolution parameter. - - reality (bool): Whether the signal on the sphere is real. If so, - conjugate symmetry is exploited to reduce computational costs. - - Returns: - torch.tensor: HEALPix pixel-space array. 
- - """ - f = torch.zeros( - samples.f_shape(sampling="healpix", nside=nside), dtype=torch.complex128 - ) - ntheta = ftm.shape[0] - index = 0 - for t in range(ntheta): - nphi = samples.nphi_ring(t, nside) - fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding_torch(ftm[t], nphi, L) - if reality and nphi == 2 * L: - f[index : index + nphi] = torch.fft.irfft( - fm_chunk[nphi // 2 :], nphi, norm="forward" - ) - else: - f[index : index + nphi] = torch.fft.ifft( - torch.fft.ifftshift(fm_chunk), norm="forward" - ) - index += nphi - return f +healpix_ifft_torch = wrap_as_torch_function(healpix_ifft_jax) def p2phi_rings(t: np.ndarray, nside: int) -> np.ndarray: @@ -823,3 +676,18 @@ def healpix_ifft_cuda( return _healpix_fft_cuda_primitive.bind( ftm, L=L, nside=nside, reality=reality, fft_type="backward", norm=norm ) + + +_healpix_fft_functions = { + "numpy": healpix_fft_numpy, + "jax": healpix_fft_jax, + "cuda": healpix_fft_cuda, + "torch": healpix_fft_torch, +} + +_healpix_ifft_functions = { + "numpy": healpix_ifft_numpy, + "jax": healpix_ifft_jax, + "cuda": healpix_ifft_cuda, + "torch": healpix_ifft_torch, +} diff --git a/s2fft/utils/quadrature_jax.py b/s2fft/utils/quadrature_jax.py index 3b18bff5..bc1426b2 100644 --- a/s2fft/utils/quadrature_jax.py +++ b/s2fft/utils/quadrature_jax.py @@ -1,13 +1,13 @@ -from functools import partial +from functools import partial as _partial import jax import jax.numpy as jnp -from jax import jit +from jax import jit as _jit from s2fft.sampling import s2_samples as samples -@partial(jit, static_argnums=(0, 1, 2)) +@_partial(_jit, static_argnums=(0, 1, 2)) def quad_weights_transform( L: int, sampling: str = "mwss", nside: int = 0 ) -> jnp.ndarray: @@ -53,7 +53,7 @@ def quad_weights_transform( raise ValueError(f"Sampling scheme sampling={sampling} not supported") -@partial(jit, static_argnums=(0, 1, 2)) +@_partial(_jit, static_argnums=(0, 1, 2)) def quad_weights(L: int = None, sampling: str = "mw", nside: int = None) -> jnp.ndarray: r""" Compute quadrature weights for :math:`\theta` and :math:`\phi` @@ -99,7 +99,7 @@ def quad_weights(L: int = None, sampling: str = "mw", nside: int = None) -> jnp. raise ValueError(f"Sampling scheme sampling={sampling} not implemented") -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def quad_weights_hp(nside: int) -> jnp.ndarray: r""" Compute HEALPix quadrature weights for :math:`\theta` and :math:`\phi` @@ -123,7 +123,7 @@ def quad_weights_hp(nside: int) -> jnp.ndarray: return jnp.ones(rings, dtype=jnp.float64) * 4 * jnp.pi / npix -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def quad_weights_gl(L: int) -> jnp.ndarray: r""" Compute GL quadrature weights for :math:`\theta` and :math:`\phi` integration. @@ -174,7 +174,7 @@ def body(arg): return weights * 2 * jnp.pi / (2 * L - 1) -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def quad_weights_dh(L: int) -> jnp.ndarray: r""" Compute DH quadrature weights for :math:`\theta` and :math:`\phi` integration. 
@@ -193,7 +193,7 @@ def quad_weights_dh(L: int) -> jnp.ndarray: return q * 2 * jnp.pi / (2 * L - 1) -@partial(jit, static_argnums=(1)) +@_partial(_jit, static_argnums=(1)) def quad_weight_dh_theta_only(theta: float, L: int) -> float: r""" Compute DH quadrature weight for :math:`\theta` integration (only), for given @@ -217,7 +217,7 @@ def quad_weight_dh_theta_only(theta: float, L: int) -> float: return w -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def quad_weights_mw(L: int) -> jnp.ndarray: r""" Compute MW quadrature weights for :math:`\theta` and :math:`\phi` integration. @@ -236,7 +236,7 @@ def quad_weights_mw(L: int) -> jnp.ndarray: return quad_weights_mw_theta_only(L) * 2 * jnp.pi / (2 * L - 1) -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def quad_weights_mwss(L: int) -> jnp.ndarray: r""" Compute MWSS quadrature weights for :math:`\theta` and :math:`\phi` integration. @@ -255,7 +255,7 @@ def quad_weights_mwss(L: int) -> jnp.ndarray: return quad_weights_mwss_theta_only(L) * 2 * jnp.pi / (2 * L) -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def quad_weights_mwss_theta_only(L: int) -> jnp.ndarray: r""" Compute MWSS quadrature weights for :math:`\theta` integration (only). @@ -282,7 +282,7 @@ def quad_weights_mwss_theta_only(L: int) -> jnp.ndarray: return q -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def quad_weights_mw_theta_only(L: int) -> jnp.ndarray: r""" Compute MW quadrature weights for :math:`\theta` integration (only). @@ -309,7 +309,7 @@ def quad_weights_mw_theta_only(L: int) -> jnp.ndarray: return q -@partial(jit, static_argnums=(0)) +@_partial(_jit, static_argnums=(0)) def mw_weights(m: int) -> float: r""" Compute MW weights given as a function of index m. diff --git a/s2fft/utils/quadrature_torch.py b/s2fft/utils/quadrature_torch.py index f669a85b..1b5c00d8 100644 --- a/s2fft/utils/quadrature_torch.py +++ b/s2fft/utils/quadrature_torch.py @@ -1,320 +1,6 @@ -import torch +from s2fft.utils import quadrature_jax as _quadrature_jax +from s2fft.utils import torch_wrapper as _torch_wrapper -from s2fft.sampling import s2_samples as samples - - -def quad_weights_transform( - L: int, sampling: str = "mwss", nside: int = 0 -) -> torch.tensor: - r""" - Compute quadrature weights for :math:`\theta` and :math:`\phi` - integration *to use in transform* for various sampling schemes. Torch implementation of - :func:`~s2fft.quadrature.quad_weights_transform`. - - Quadrature weights to use in transform for MWSS correspond to quadrature weights - are twice the base resolution, i.e. 2 * L. - - Args: - L (int): Harmonic band-limit. - - sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mwss", "dh", "gl", "healpix}. Defaults to "mwss". - - nside (int, optional): HEALPix Nside resolution parameter. Only required - if sampling="healpix". Defaults to None. - - Raises: - ValueError: Invalid sampling scheme. - - Returns: - torch.tensor: Quadrature weights *to use in transform* for sampling scheme for - each :math:`\theta` (weights are identical as :math:`\phi` varies for given - :math:`\theta`). 
- - """ - if sampling.lower() == "mwss": - return quad_weights_mwss_theta_only(2 * L) * 2 * torch.pi / (2 * L) - - elif sampling.lower() == "dh": - return quad_weights_dh(L) - - elif sampling.lower() == "gl": - return quad_weights_gl(L) - - elif sampling.lower() == "healpix": - return quad_weights_hp(nside) - - else: - raise ValueError(f"Sampling scheme sampling={sampling} not supported") - - -def quad_weights( - L: int = None, sampling: str = "mw", nside: int = None -) -> torch.tensor: - r""" - Compute quadrature weights for :math:`\theta` and :math:`\phi` - integration for various sampling schemes. Torch implementation of - :func:`~s2fft.quadrature.quad_weights`. - - Args: - L (int, optional): Harmonic band-limit. Required if sampling not healpix. - Defaults to None. - - sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - - spin (int, optional): Harmonic spin. Defaults to 0. - - nside (int, optional): HEALPix Nside resolution parameter. Only required - if sampling="healpix". Defaults to None. - - Raises: - ValueError: Invalid sampling scheme. - - Returns: - torch.tensor: Quadrature weights for sampling scheme for each :math:`\theta` - (weights are identical as :math:`\phi` varies for given :math:`\theta`). - - """ - if sampling.lower() == "mw": - return quad_weights_mw(L) - - elif sampling.lower() == "mwss": - return quad_weights_mwss(L) - - elif sampling.lower() == "dh": - return quad_weights_dh(L) - - elif sampling.lower() == "gl": - return quad_weights_gl(L) - - elif sampling.lower() == "healpix": - return quad_weights_hp(nside) - - else: - raise ValueError(f"Sampling scheme sampling={sampling} not implemented") - - -def quad_weights_hp(nside: int) -> torch.tensor: - r""" - Compute HEALPix quadrature weights for :math:`\theta` and :math:`\phi` - integration. Torch implementation of :func:`s2fft.quadrature.quad_weights_hp`. - - Note: - HEALPix weights are identical for all pixels. Nevertheless, an array of - weights is returned (with identical values) for consistency of interface - across other sampling schemes. - - Args: - nside (int): HEALPix Nside resolution parameter. - - Returns: - torch.tensor: Weights computed for each :math:`\theta` (all weights in array are - identical). - - """ - npix = 12 * nside**2 - rings = samples.ntheta(sampling="healpix", nside=nside) - return torch.ones(rings, dtype=torch.float64) * 4 * torch.pi / npix - - -def quad_weights_gl(L: int) -> torch.tensor: - r""" - Compute GL quadrature weights for :math:`\theta` and :math:`\phi` integration. - - Args: - L (int): Harmonic band-limit. - - Returns: - np.ndarray: Weights computed for each :math:`\theta` (weights are identical - as :math:`\phi` varies for given :math:`\theta`). 
- - """ - x1, x2 = -1.0, 1.0 - ntheta = samples.ntheta(L, "gl") - weights = torch.zeros(ntheta, dtype=torch.float64) - - m = int((L + 1) / 2) - x1 = 0.5 * (x2 - x1) - - i = torch.arange(1, m + 1) - z = torch.cos(torch.pi * (i.type(torch.float64) - 0.25) / (L + 0.5)) - z1 = 2.0 - while torch.max(torch.abs(z - z1)) > 1e-14: - p1 = 1.0 - p2 = 0.0 - for j in range(1, L + 1): - p3 = p2 - p2 = p1 - p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j - pp = L * (z * p1 - p2) / (z * z - 1.0) - z1 = z - z = z1 - p1 / pp - - weights[i - 1] = 2.0 * x1 / ((1.0 - z**2) * pp * pp) - weights[L + 1 - i - 1] = weights[i - 1] - - return weights * 2 * torch.pi / (2 * L - 1) - - -def quad_weights_dh(L: int) -> torch.tensor: - r""" - Compute DH quadrature weights for :math:`\theta` and :math:`\phi` integration. - Torch implementation of :func:`s2fft.quadrature.quad_weights_dh`. - - Args: - L (int): Harmonic band-limit. - - Returns: - torch.tensor: Weights computed for each :math:`\theta` (weights are identical - as :math:`\phi` varies for given :math:`\theta`). - - """ - q = quad_weight_dh_theta_only(samples.thetas(L, sampling="dh"), L) - - return q * 2 * torch.pi / (2 * L - 1) - - -def quad_weight_dh_theta_only(theta: float, L: int) -> float: - r""" - Compute DH quadrature weight for :math:`\theta` integration (only), for given - :math:`\theta`. Torch implementation of :func:`s2fft.quadrature.quad_weights_dh_theta_only`. - - Args: - theta (float): :math:`\theta` angle for which to compute weight. - - L (int): Harmonic band-limit. - - Returns: - float: Weight computed for each :math:`\theta`. - - """ - w = 0.0 - for k in range(0, L): - w += torch.sin((2 * k + 1) * torch.from_numpy(theta)) / (2 * k + 1) - - w *= 2 / L * torch.sin(torch.from_numpy(theta)) - - return w - - -def quad_weights_mw(L: int) -> torch.tensor: - r""" - Compute MW quadrature weights for :math:`\theta` and :math:`\phi` integration. - Torch implementation of :func:`s2fft.quadrature.quad_weights_mw`. - - Args: - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Returns: - torch.tensor: Weights computed for each :math:`\theta` (weights are identical - as :math:`\phi` varies for given :math:`\theta`). - - """ - return quad_weights_mw_theta_only(L) * 2 * torch.pi / (2 * L - 1) - - -def quad_weights_mwss(L: int) -> torch.tensor: - r""" - Compute MWSS quadrature weights for :math:`\theta` and :math:`\phi` integration. - JAX implementation of :func:`s2fft.quadrature.quad_weights_mwss`. - - Args: - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Returns: - torch.tensor: Weights computed for each :math:`\theta` (weights are identical - as :math:`\phi` varies for given :math:`\theta`). - - """ - return quad_weights_mwss_theta_only(L) * 2 * torch.pi / (2 * L) - - -def quad_weights_mwss_theta_only(L: int) -> torch.tensor: - r""" - Compute MWSS quadrature weights for :math:`\theta` integration (only). - Torch implementation of :func:`s2fft.quadrature.quad_weights_mwss_theta_only`. - - Args: - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Returns: - np.ndarray: Weights computed for each :math:`\theta`. 
- - """ - w = torch.zeros(2 * L, dtype=torch.complex128) - - for i in range(-(L - 1) + 1, L + 1): - w[i + L - 1] = mw_weights(i - 1) - - wr = torch.real(torch.fft.fft(torch.fft.ifftshift(w), norm="backward")) / (2 * L) - q = wr[: L + 1] - q[1:L] += torch.flip(wr, dims=[0])[: L - 1] - - return q - - -def quad_weights_mw_theta_only(L: int) -> torch.tensor: - r""" - Compute MW quadrature weights for :math:`\theta` integration (only). - Torch implementation of :func:`s2fft.quadrature.quad_weights_mw_theta_only`. - - Args: - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Returns: - torch.tensor: Weights computed for each :math:`\theta`. - - """ - w = torch.zeros(2 * L - 1, dtype=torch.complex128) - for i in range(-(L - 1), L): - w[i + L - 1] = mw_weights(i) - - w *= torch.exp(-1j * torch.arange(-(L - 1), L) * torch.pi / (2 * L - 1)) - wr = torch.real(torch.fft.fft(torch.fft.ifftshift(w), norm="backward")) / ( - 2 * L - 1 - ) - q = wr[:L] - q[: L - 1] += torch.flip(wr, dims=[0])[: L - 1] - - return q - - -def mw_weights(m: int) -> float: - r""" - Compute MW weights given as a function of index m. - - MW weights are defined by - - .. math:: - - w(m^\prime) = \int_0^\pi \text{d} \theta \sin \theta \exp(i m^\prime\theta), - - which can be computed analytically. - - Args: - m (int): Harmonic weight index. - - Returns: - float: MW weight. - - """ - if m == 1: - return 1j * torch.pi / 2 - - elif m == -1: - return -1j * torch.pi / 2 - - elif m % 2 == 0: - return 2 / (1 - m**2) - - else: - return 0 +_torch_wrapper.populate_namespace_by_wrapping_functions_in_module( + globals(), _quadrature_jax +) diff --git a/s2fft/utils/resampling_jax.py b/s2fft/utils/resampling_jax.py index cff96906..11f1073a 100644 --- a/s2fft/utils/resampling_jax.py +++ b/s2fft/utils/resampling_jax.py @@ -1,12 +1,12 @@ -from functools import partial +from functools import partial as _partial import jax.numpy as jnp -from jax import jit +from jax import jit as _jit from s2fft.sampling import s2_samples as samples -@partial(jit, static_argnums=(1)) +@_partial(_jit, static_argnums=(1)) def mw_to_mwss(f_mw: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: r""" Convert signal on the sphere from MW sampling to MWSS sampling. 
@@ -38,7 +38,7 @@ def mw_to_mwss(f_mw: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: return mw_to_mwss_phi(mw_to_mwss_theta(f_mw, L, spin), L) -@partial(jit, static_argnums=(1)) +@_partial(_jit, static_argnums=(1)) def mw_to_mwss_theta(f_mw: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: r""" Convert :math:`\theta` component of signal on the sphere from MW sampling to @@ -96,7 +96,7 @@ def mw_to_mwss_theta(f_mw: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: return unextend(f_mwss_ext, L, sampling="mwss") -@partial(jit, static_argnums=(1)) +@_partial(_jit, static_argnums=(1)) def mw_to_mwss_phi(f_mw: jnp.ndarray, L: int) -> jnp.ndarray: r""" Convert :math:`\phi` component of signal on the sphere from MW sampling to @@ -142,7 +142,7 @@ def mw_to_mwss_phi(f_mw: jnp.ndarray, L: int) -> jnp.ndarray: ) -@partial(jit, static_argnums=(1, 3)) +@_partial(_jit, static_argnums=(1, 3)) def periodic_extension( f: jnp.ndarray, L: int, spin: int = 0, sampling: str = "mw" ) -> jnp.ndarray: @@ -233,7 +233,7 @@ def periodic_extension( ) -@partial(jit, static_argnums=(1, 2)) +@_partial(_jit, static_argnums=(1, 2)) def unextend(f_ext: jnp.ndarray, L: int, sampling: str = "mw") -> jnp.ndarray: r""" Unextend MW/MWSS sampled signal from :math:`\theta` domain @@ -271,7 +271,7 @@ def unextend(f_ext: jnp.ndarray, L: int, sampling: str = "mw") -> jnp.ndarray: ) -@partial(jit, static_argnums=(1)) +@_partial(_jit, static_argnums=(1)) def upsample_by_two_mwss(f: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: r""" Upsample MWSS sampled signal on the sphere defined on domain :math:`[0,\pi]` @@ -304,7 +304,7 @@ def upsample_by_two_mwss(f: jnp.ndarray, L: int, spin: int = 0) -> jnp.ndarray: return jnp.squeeze(f_ext) -@partial(jit, static_argnums=(1)) +@_partial(_jit, static_argnums=(1)) def upsample_by_two_mwss_ext(f_ext: jnp.ndarray, L: int) -> jnp.ndarray: r""" Upsample an extended MWSS sampled signal on the sphere defined on domain @@ -343,7 +343,7 @@ def upsample_by_two_mwss_ext(f_ext: jnp.ndarray, L: int) -> jnp.ndarray: ) -@partial(jit, static_argnums=(1)) +@_partial(_jit, static_argnums=(1)) def periodic_extension_spatial_mwss( f: jnp.ndarray, L: int, spin: int = 0 ) -> jnp.ndarray: diff --git a/s2fft/utils/resampling_torch.py b/s2fft/utils/resampling_torch.py index f5a44c97..e65ffed6 100644 --- a/s2fft/utils/resampling_torch.py +++ b/s2fft/utils/resampling_torch.py @@ -1,376 +1,6 @@ -import torch +from s2fft.utils import resampling_jax as _resampling_jax +from s2fft.utils import torch_wrapper as _torch_wrapper -from s2fft.sampling import s2_samples as samples - - -def mw_to_mwss(f_mw: torch.tensor, L: int, spin: int = 0) -> torch.tensor: - r""" - Convert signal on the sphere from MW sampling to MWSS sampling. - - Conversion is performed by first performing a period extension in - :math:`\theta` to :math:`2\pi`, followed by zero padding in harmonic space. The - resulting signal is then unextend back to the :math:`\theta` domain of - :math:`[0,\pi]`. Second, zero padding in harmonic space corresponding to - :math:`\phi` is performed. - - Torch implementation of :func:`~s2fft.resampling.mw_to_mwss`. - - Args: - f_mw (torch.tensor): Signal on the sphere sampled with MW sampling. - - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Returns: - torch.tensor: Signal on the sphere sampled with MWSS sampling. 
- - """ - if f_mw.ndim == 2: - return torch.squeeze( - mw_to_mwss_phi(mw_to_mwss_theta(torch.unsqueeze(f_mw, 0), L, spin), L) - ) - else: - return mw_to_mwss_phi(mw_to_mwss_theta(f_mw, L, spin), L) - - -def mw_to_mwss_theta(f_mw: torch.tensor, L: int, spin: int = 0) -> torch.tensor: - r""" - Convert :math:`\theta` component of signal on the sphere from MW sampling to - MWSS sampling. - - Conversion is performed by first performing a period extension in - :math:`\theta` to :math:`2\pi`, followed by zero padding in harmonic space. The - resulting signal is then unextend back to the :math:`\theta` domain of - :math:`[0,\pi]`. - - Torch implementation of :func:`~s2fft.resampling.mw_to_mwss_theta`. - - - Args: - f_mw (torch.tensor): Signal on the sphere sampled with MW sampling. - - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Raises: - ValueError: Input spherical signal must have shape matching MW sampling. - - Returns: - torch.tensor: Signal on the sphere with MWSS sampling in :math:`\theta` and MW - sampling in :math:`\phi`. - - """ - f_mw_ext = periodic_extension(f_mw, L, spin=spin, sampling="mw") - fmp_mwss_ext = torch.zeros( - (f_mw_ext.shape[0], 2 * L, 2 * L - 1), dtype=torch.complex128 - ) - - fmp_mwss_ext[:, 1:, :] = torch.fft.fftshift( - torch.fft.fft(f_mw_ext, axis=-2, norm="forward"), dim=[-2] - ) - - fmp_mwss_ext[:, 1:, :] = torch.einsum( - "...blp,l->...blp", - fmp_mwss_ext[:, 1:, :], - torch.exp( - -1j - * torch.arange(-(L - 1), L, dtype=torch.float64) - * torch.pi - / (2 * L - 1) - ), - ) - - f_mwss_ext = torch.conj( - torch.fft.fft( - torch.fft.ifftshift(torch.conj(fmp_mwss_ext), dim=[-2]), - axis=-2, - norm="backward", - ) - ) - - return unextend(f_mwss_ext, L, sampling="mwss") - - -def mw_to_mwss_phi(f_mw: torch.tensor, L: int) -> torch.tensor: - r""" - Convert :math:`\phi` component of signal on the sphere from MW sampling to - MWSS sampling. - - Conversion is performed by zero padding in harmonic space. - - Torch implementation of :func:`~s2fft.resampling.mw_to_mwss_phi`. - - - Note: - Can work with arbitrary number of :math:`\theta` samples. Hence, to convert - both :math:`(\theta,\phi)` sampling to MWSS, can use :func:`~mw_to_mwss_theta` - to first convert :math:`\theta` sampling before using this function to convert - the :math:`\phi` sampling. - - Args: - f_mw (torch.tensor): Signal on the sphere sampled with MW sampling in - :math:`\phi` and arbitrary number of samples in - - L (int): Harmonic band-limit. - - Raises: - ValueError: Input spherical signal must have number of samples in :math:`\phi` - matching MW sampling. - - Returns: - torch.tensor: Signal on the sphere with MWSS sampling in :math:`\phi` and - sampling in :math:`\theta` of the input signal. - - """ - f_mwss = torch.zeros((f_mw.shape[0], L + 1, 2 * L), dtype=torch.complex128) - f_mwss[:, :, 1:] = torch.fft.fftshift( - torch.fft.fft(f_mw, axis=-1, norm="forward"), dim=[-1] - ) - - return torch.conj( - torch.fft.fft( - torch.fft.ifftshift(torch.conj(f_mwss), dim=[-1]), - axis=-1, - norm="backward", - ) - ) - - -def periodic_extension( - f: torch.tensor, L: int, spin: int = 0, sampling: str = "mw" -) -> torch.tensor: - r""" - Perform period extension of MW/MWSS signal on the sphere in harmonic - domain, extending :math:`\theta` domain from :math:`[0,\pi]` to :math:`[0,2\pi]`. - Torch implementation of :func:`~s2fft.resampling.periodic_extension`. - - Args: - f (torch.tensor): Signal on the sphere sampled with MW/MWSS sampling scheme. 
- - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss"}. Defaults to "mw". - - Raises: - ValueError: Only MW/MWW sampling schemes supported. - - Returns: - torch.tensor: Signal on the sphere extended to :math:`\theta` domain - :math:`[0,2\pi]`, in same scheme (MW/MWSS) as input. - - """ - ntheta = samples.ntheta(L, sampling) - nphi = samples.nphi_equiang(L, sampling) - ntheta_ext = samples.ntheta_extension(L, sampling) - m_offset = 1 if sampling == "mwss" else 0 - - f_ext = torch.zeros((f.shape[0], ntheta_ext, nphi), dtype=torch.complex128) - f_ext[:, 0:ntheta, 0:nphi] = f[:, 0:ntheta, 0:nphi] - f_ext = torch.fft.fftshift(torch.fft.fft(f_ext, dim=-1, norm="backward"), dim=[-1]) - - f_ext[ - :, - L + m_offset : 2 * L - 1 + m_offset, - m_offset : 2 * L - 1 + m_offset, - ] = torch.flip( - f_ext[ - :, - m_offset : L - 1 + m_offset, - m_offset : 2 * L - 1 + m_offset, - ], - dims=[-2], - ) - f_ext[ - :, - L + m_offset : 2 * L - 1 + m_offset, - m_offset : 2 * L - 1 + m_offset, - ] *= (-1) ** (torch.arange(-(L - 1), L)) - if hasattr(spin, "size"): - f_ext[ - :, - L + m_offset : 2 * L - 1 + m_offset, - m_offset : 2 * L - 1 + m_offset, - ] = torch.einsum( - "nlm,n->nlm", - f_ext[ - :, - L + m_offset : 2 * L - 1 + m_offset, - m_offset : 2 * L - 1 + m_offset, - ], - (-1) ** spin, - ) - else: - f_ext[ - :, - L + m_offset : 2 * L - 1 + m_offset, - m_offset : 2 * L - 1 + m_offset, - ] *= (-1) ** spin - - return ( - torch.conj( - torch.fft.fft( - torch.fft.ifftshift(torch.conj(f_ext), dim=[-1]), - axis=-1, - norm="backward", - ) - ) - / nphi - ) - - -def unextend(f_ext: torch.tensor, L: int, sampling: str = "mw") -> torch.tensor: - r""" - Unextend MW/MWSS sampled signal from :math:`\theta` domain - :math:`[0,2\pi]` to :math:`[0,\pi]`. - - Args: - f_ext (torch.tensor): Signal on the sphere sampled on extended :math:`\theta` - domain :math:`[0,2\pi]`. - - L (int): Harmonic band-limit. - - sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss"}. Defaults to "mw". - - Raises: - ValueError: Only MW/MWW sampling schemes supported. - - ValueError: Period extension must have correct shape. - - Returns: - torch.tensor: Signal on the sphere sampled on :math:`\theta` domain - :math:`[0,\pi]`. - - """ - if sampling.lower() == "mw": - return f_ext[:, 0:L, :] - - elif sampling.lower() == "mwss": - return f_ext[:, 0 : L + 1, :] - - else: - raise ValueError( - "Only mw and mwss supported for periodic extension " - f"(not sampling={sampling})" - ) - - -def upsample_by_two_mwss(f: torch.tensor, L: int, spin: int = 0) -> torch.tensor: - r""" - Upsample MWSS sampled signal on the sphere defined on domain :math:`[0,\pi]` - by a factor of two. - - Upsampling is performed by a periodic extension in :math:`\theta` to - :math:`[0,2\pi]`, followed by zero-padding in harmonic space, followed by - unextending :math:`\theta` domain back to :math:`[0,\pi]`. - - Torch implementation of :func:`~s2fft.resampling.upsample_by_two_mwss`. - - Args: - f (torch.tensor): Signal on the sphere sampled with MWSS sampling scheme, sampled - at resolution L. - - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Returns: - torch.tensor: Signal on the sphere sampled with MWSS sampling scheme, sampling at - resolution 2*L. 
- - """ - if f.ndim == 2: - f = torch.unsqueeze(f, 0) - f_ext = periodic_extension_spatial_mwss(f, L, spin) - f_ext = upsample_by_two_mwss_ext(f_ext, L) - f_ext = unextend(f_ext, 2 * L, sampling="mwss") - return torch.squeeze(f_ext) - - -def upsample_by_two_mwss_ext(f_ext: torch.tensor, L: int) -> torch.tensor: - r""" - Upsample an extended MWSS sampled signal on the sphere defined on domain - :math:`[0,2\pi]` by a factor of two. - - Upsampling is performed by zero-padding in harmonic space. Torch implementation of - :func:`~s2fft.resampling.upsample_by_two_mwss_ext`. - - Args: - f_ext (torch.tensor): Signal on the sphere sampled on extended MWSS sampling - scheme on domain :math:`[0,2\pi]`, sampled at resolution L. - - L (int): Harmonic band-limit. - - Returns: - torch.tensor: Signal on the sphere sampled on extended MWSS sampling scheme on - domain :math:`[0,2\pi]`, sampling at resolution 2*L. - - """ - nphi = 2 * L - ntheta_ext = 2 * L - - f_ext = torch.fft.fftshift(torch.fft.fft(f_ext, axis=-2, norm="forward"), dim=[-2]) - - ntheta_ext_up = 2 * ntheta_ext - f_ext_up = torch.zeros( - (f_ext.shape[0], ntheta_ext_up, nphi), dtype=torch.complex128 - ) - f_ext_up[:, L : ntheta_ext + L, :nphi] = f_ext[:, 0:ntheta_ext, :nphi] - return torch.conj( - torch.fft.fft( - torch.fft.ifftshift(torch.conj(f_ext_up), dim=[-2]), - axis=-2, - norm="backward", - ) - ) - - -def periodic_extension_spatial_mwss( - f: torch.tensor, L: int, spin: int = 0 -) -> torch.tensor: - r""" - Perform period extension of MWSS signal on the sphere in spatial domain, - extending :math:`\theta` domain from :math:`[0,\pi]` to :math:`[0,2\pi]`. - - For the MWSS sampling scheme, it is possible to do the period extension in - :math:`\theta` in the spatial domain. This is not possible for the MW sampling - scheme. - - Torch implementation of :func:`~s2fft.resampling.periodic_extension_spatial_mwss`. - - Args: - f (torch.tensor): Signal on the sphere sampled with MWSS sampling scheme. - - L (int): Harmonic band-limit. - - spin (int, optional): Harmonic spin. Defaults to 0. - - Returns: - torch.tensor: Signal on the sphere extended to :math:`\theta` domain - :math:`[0,2\pi]`, in MWSS sampling scheme. - - """ - ntheta = L + 1 - nphi = 2 * L - ntheta_ext = 2 * L - - f_ext = torch.zeros((f.shape[0], ntheta_ext, nphi), dtype=torch.complex128) - f_ext[:, 0:ntheta, 0:nphi] = f[:, 0:ntheta, 0:nphi] - if hasattr(spin, "size"): - f_ext[:, ntheta:, 0 : 2 * L] = torch.einsum( - "btp,b->btp", - torch.fft.fftshift( - torch.flip(f[:, 1 : ntheta - 1, 0 : 2 * L], dims=[-2]), dim=[-1] - ), - (-1) ** spin, - ) - else: - f_ext[:, ntheta:, 0 : 2 * L] = (-1) ** spin * torch.fft.fftshift( - torch.flip(f[:, 1 : ntheta - 1, 0 : 2 * L], dims=[-2]), dim=[-1] - ) - return f_ext +_torch_wrapper.populate_namespace_by_wrapping_functions_in_module( + globals(), _resampling_jax +) diff --git a/s2fft/utils/signal_generator.py b/s2fft/utils/signal_generator.py index 799de9b9..d3e4aca9 100644 --- a/s2fft/utils/signal_generator.py +++ b/s2fft/utils/signal_generator.py @@ -1,7 +1,7 @@ from __future__ import annotations import numpy as np -import torch +from numpy.typing import DTypeLike from s2fft.sampling import s2_samples as samples from s2fft.sampling import so3_samples as wigner_samples @@ -10,7 +10,8 @@ def complex_normal( rng: np.random.Generator, size: int | tuple[int], - var: float, + var: float = 1.0, + dtype: DTypeLike = np.complex128, ) -> np.ndarray: """ Generate array of samples from zero-mean complex normal distribution. 
@@ -23,14 +24,19 @@ def complex_normal(
         rng: Numpy random generator object to generate samples using.
         size: Output shape of array to generate.
         var: Variance of complex normal distribution to generate samples from.
+        dtype: Data type of generated array.

     Returns:
-        Complex-valued array of shape `size` contained generated samples.
+        Complex-valued array of shape `size` and data type `dtype` containing generated
+        samples.

     """
-    return (rng.standard_normal(size) + 1j * rng.standard_normal(size)) * (
-        var / 2
-    ) ** 0.5
+    # NumPy handling of scalars differs for floating and complex types with the latter
+    # being returned as complex scalar objects rather than arrays of shape () so we
+    # use asarray here to ensure output is always an array
+    return np.asarray(
+        (rng.standard_normal(size) + 1j * rng.standard_normal(size)) * (var / 2) ** 0.5
+    ).astype(dtype)


 def complex_el_and_m_indices(L: int, min_el: int) -> tuple[np.ndarray, np.ndarray]:
@@ -73,8 +79,7 @@ def generate_flm(
     L_lower: int = 0,
     spin: int = 0,
     reality: bool = False,
-    using_torch: bool = False,
-) -> np.ndarray | torch.Tensor:
+) -> np.ndarray:
     r"""
     Generate a 2D set of random harmonic coefficients.
@@ -92,8 +97,6 @@ def generate_flm(
         reality (bool, optional): Reality of signal. Defaults to False.

-        using_torch (bool, optional): Desired frontend functionality. Defaults to False.
-
     Returns:
         np.ndarray: Random set of spherical harmonic coefficients.
@@ -117,7 +120,7 @@ def generate_flm(
     else:
         # Non-real signal so generate independent complex coefficients for negative m
         flm[el_indices, L - 1 - m_indices] = complex_normal(rng, len_indices, var=2)
-    return torch.from_numpy(flm) if using_torch else flm
+    return flm


 def generate_flmn(
@@ -126,8 +129,7 @@ def generate_flmn(
     N: int = 1,
     L_lower: int = 0,
     reality: bool = False,
-    using_torch: bool = False,
-) -> np.ndarray | torch.Tensor:
+) -> np.ndarray:
     r"""
     Generate a 3D set of random Wigner coefficients.
@@ -146,8 +148,6 @@ def generate_flmn(
         reality (bool, optional): Reality of signal. Defaults to False.

-        using_torch (bool, optional): Desired frontend functionality. Defaults to False.
-
     Returns:
         np.ndarray: Random set of Wigner coefficients.
@@ -198,4 +198,4 @@ def generate_flmn(
         flmn[N - 1 + n, el_indices, L - 1 - m_indices] = complex_normal(
             rng, len_indices, var=2
         )
-    return torch.from_numpy(flmn) if using_torch else flmn
+    return flmn
diff --git a/s2fft/utils/torch_wrapper.py b/s2fft/utils/torch_wrapper.py
new file mode 100644
index 00000000..1f12894f
--- /dev/null
+++ b/s2fft/utils/torch_wrapper.py
@@ -0,0 +1,261 @@
+"""
+Utilities for wrapping JAX functions for use in PyTorch.
+
+Based on a Gist by Matt Johnson at
+https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9
+
+and the jax2torch package by Phil Wang
+https://github.com/lucidrains/jax2torch
+
+which is released under an MIT license
+
+Copyright (c) 2021 Phil Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from __future__ import annotations + +from functools import wraps +from inspect import getmembers, isroutine, signature +from types import ModuleType +from typing import Any, Callable, Dict, List, Tuple, TypeVar, Union + +import jax +import jax.dlpack +from jax.tree_util import tree_map + +try: + import torch + import torch.utils.dlpack + from torch import Tensor + + TORCH_AVAILABLE = True +except ImportError: + Tensor = None + TORCH_AVAILABLE = False + +T = TypeVar("T") +PyTree = Union[Dict[Any, "PyTree"], List["PyTree"], Tuple["PyTree"], T] + + +def check_torch_available() -> None: + """Raise an error if Torch is not importable.""" + if not TORCH_AVAILABLE: + msg = ( + "torch needs to be installed to use torch wrapper functionality but could\n" + "not be imported. Install s2fft with torch extra using:\n" + " pip install s2fft[torch]\n" + "to allow use of torch wrapper functionality." + ) + raise RuntimeError(msg) + + +def jax_array_to_torch_tensor(jax_array: jax.Array) -> Tensor: + """ + Convert from JAX array to Torch tensor via mutual DLPack support. + + Args: + jax_array: JAX array to convert. + + Returns: + Torch tensor object with equivalent data to `jax_array`. + + """ + try: + return torch.utils.dlpack.from_dlpack(jax_array) + except AttributeError: + # jax.Array instances in earlier JAX versions lack a __dlpack_device__ attribute + # and require explicitly packing into a DLPack capsule with jax.dlpack.to_dlpack + return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array)) + + +def torch_tensor_to_jax_array(torch_tensor: Tensor) -> jax.Array: + """ + Convert from Torch tensor to JAX array via mutual DLPack support. + + Args: + torch_tensor: Torch tensor to convert. + + Returns: + JAX array object with equivalent data to `torch_tensor`. + + """ + # JAX currently only support DLPack arrays with trivial strides so force torch + # tensor to be contiguous before DLPack conversion + # https://github.com/google/jax/issues/8082 + torch_tensor = torch_tensor.contiguous() + # Torch does lazy conjugation using flag bits and DLPack does not support this + # https://github.com/data-apis/array-api-compat/issues/173#issuecomment-2272192054 + # so we explicitly resolve any conjugacy operations implied by bit before conversion + torch_tensor = torch_tensor.resolve_conj() + # DLPack compatibility does not support tensors that require gradient so detach. 
+    # this is intended for use when wrapping JAX code, detaching the tensor from
+    # gradient values should not be problematic as derivatives will be separately
+    # routed via JAX
+    torch_tensor = torch_tensor.detach()
+    try:
+        return jax.dlpack.from_dlpack(torch_tensor)
+    except TypeError:
+        # earlier JAX versions require explicitly converting external arrays to a
+        # DLPack capsule before passing to jax.dlpack.from_dlpack
+        return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
+
+
+def tree_map_jax_array_to_torch_tensor(
+    jax_pytree: PyTree[jax.Array],
+) -> PyTree[Tensor]:
+    """
+    Convert from a pytree with JAX arrays to corresponding pytree with Torch tensors.
+
+    Args:
+        jax_pytree: Pytree of JAX arrays or non-array values.
+
+    Returns:
+        Pytree with equivalent structure but any JAX arrays mapped to Torch tensors.
+
+    """
+    return tree_map(
+        lambda t: jax_array_to_torch_tensor(t) if isinstance(t, jax.Array) else t,
+        jax_pytree,
+    )
+
+
+def tree_map_torch_tensor_to_jax_array(
+    torch_pytree: PyTree[Tensor],
+) -> PyTree[jax.Array]:
+    """
+    Convert from a pytree with Torch tensors to corresponding pytree with JAX arrays.
+
+    Args:
+        torch_pytree: Pytree of Torch tensors or non-array values.
+
+    Returns:
+        Pytree with equivalent structure but any Torch tensors mapped to JAX arrays.
+
+    """
+    return tree_map(
+        lambda t: torch_tensor_to_jax_array(t) if isinstance(t, Tensor) else t,
+        torch_pytree,
+    )
+
+
+def wrap_as_torch_function(
+    jax_function: Callable, differentiable_argnames: None | tuple[str] = None
+) -> Callable:
+    """
+    Wrap a function implemented using the JAX API to be callable within Torch.
+
+    Deals with conversion of argument(s) from Torch tensor(s) to JAX array(s), and of
+    return value(s) from JAX array(s) to Torch tensor(s), as well as recording
+    context needed to compute reverse-mode gradients in Torch using JAX automatic
+    differentiation support if differentiable arguments are present.
+
+    Args:
+        jax_function: JAX function to wrap.
+        differentiable_argnames: Names of arguments of `jax_function` with respect to
+            which the function output(s) are differentiable, and with respect to which
+            gradients should be computed in the Torch backwards pass. If `None` (the
+            default) the names of all arguments which are annotated as being
+            `jax.Array` instances will be used.
+
+    Returns:
+        Wrapped function callable from Torch.
+ + """ + sig = signature(jax_function) + if differentiable_argnames is None: + differentiable_argnames = tuple( + name + for name, param in sig.parameters.items() + if isinstance(param.annotation, type) + and issubclass(param.annotation, jax.Array) + ) + for argname in differentiable_argnames: + if argname not in sig.parameters: + msg = f"{argname} passed is not a valid argument to {jax_function}" + raise ValueError(msg) + + @wraps(jax_function) + def torch_function(*args, **kwargs): + check_torch_available() + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + differentiable_args = tuple( + bound_args.arguments[argname] for argname in differentiable_argnames + ) + + def jax_function_diff_args_only(*differentiable_args): + for key, value in zip(differentiable_argnames, differentiable_args): + bound_args.arguments[key] = value + return jax_function(*bound_args.args, **bound_args.kwargs) + + class WrappedJaxFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + args = tree_map_torch_tensor_to_jax_array(args) + primals_out, ctx.vjp = jax.vjp(jax_function_diff_args_only, *args) + return tree_map_jax_array_to_torch_tensor(primals_out) + + @staticmethod + def backward(ctx, *grad_outputs): + # JAX and PyTorch use different conventions for derivatives of complex + # functions (see https://github.com/jax-ml/jax/issues/4891) so we need + # to conjugate the inputs to and outputs from VJP to get equivalent + # behaviour to backward method on torch tensors + grad_outputs = tree_map(lambda g: g.conj(), grad_outputs) + jax_grad_outputs = tree_map_torch_tensor_to_jax_array(grad_outputs) + jax_grad_inputs = ctx.vjp(*jax_grad_outputs) + grad_inputs = tree_map_jax_array_to_torch_tensor(jax_grad_inputs) + return tree_map(lambda g: g.conj(), grad_inputs) + + return WrappedJaxFunction.apply(*differentiable_args) + + docstring_replacements = { + "JAX": "Torch", + "jnp.ndarray": "torch.Tensor", + "jax.Array": "torch.Tensor", + } + if torch_function.__doc__ is not None: + for original, new in docstring_replacements.items(): + torch_function.__doc__ = torch_function.__doc__.replace(original, new) + + torch_function.__annotations__ = torch_function.__annotations__.copy() + for name, annotation in torch_function.__annotations__.items(): + if isinstance(annotation, type) and issubclass(annotation, jax.Array): + torch_function.__annotations__[name] = Tensor + + return torch_function + + +def populate_namespace_by_wrapping_functions_in_module( + namespace: dict, module: ModuleType +) -> None: + """ + Populate a namespace by wrapping all (JAX) functions in a module as Torch functions. + + Args: + namespace: Namespace to define wrapped functions in. + module: Source module for (JAX) functions to wrap. Note all functions in this + module without a preceding underscore in their name will be wrapped + irrespective of whether they are defined in the module or not. 
+ + """ + for name, function in getmembers(module, isroutine): + if not name.startswith("_"): + namespace[name] = wrap_as_torch_function(function) diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 20096d17..115a4b24 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -23,7 +23,7 @@ reality_to_test = [True, False] methods_to_test = ["numpy", "jax", "torch"] recursions_to_test = ["price-mcewen", "risbo", "auto"] -iter_to_test = [0, 3] +iter_to_test = [0, 1] @pytest.mark.parametrize("L", L_to_test) diff --git a/tests/test_spherical_transform.py b/tests/test_spherical_transform.py index 1645b14d..0caf5f2a 100644 --- a/tests/test_spherical_transform.py +++ b/tests/test_spherical_transform.py @@ -3,6 +3,7 @@ import numpy as np import pyssht as ssht import pytest +import torch from s2fft.recursions.price_mcewen import generate_precomputes from s2fft.sampling import s2_samples as samples @@ -15,7 +16,7 @@ spin_to_test = [-2, 0, 1] nside_to_test = [4, 5] sampling_to_test = ["mw", "mwss", "dh", "gl"] -method_to_test = ["numpy", "jax"] +method_to_test = ["numpy", "jax", "torch"] reality_to_test = [False, True] multiple_gpus = [False, True] @@ -59,7 +60,7 @@ def test_transform_inverse( else: precomps = None f = spherical.inverse( - flm, + torch.from_numpy(flm) if method == "orch" else flm, L, spin, sampling=sampling, @@ -87,10 +88,9 @@ def test_transform_inverse_healpix( flm = flm_generator(L=L, spin=0, reality=True) flm_hp = samples.flm_2d_to_hp(flm, L) f_check = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) - precomps = generate_precomputes(L, 0, sampling, nside, False) f = spherical.inverse( - flm, + torch.from_numpy(flm) if method == "torch" else flm, L, spin=0, nside=nside, @@ -142,8 +142,9 @@ def test_transform_forward( precomps = generate_precomputes(L, spin, sampling, None, True, L_lower) else: precomps = None + flm_check = spherical.forward( - f, + torch.from_numpy(f) if method == "torch" else f, L, spin, sampling=sampling, @@ -173,10 +174,9 @@ def test_transform_forward_healpix( flm = flm_generator(L=L, spin=0, reality=True) flm_hp = samples.flm_2d_to_hp(flm, L) f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) - precomps = generate_precomputes(L, 0, sampling, nside, True) flm_check = spherical.forward( - f, + torch.from_numpy(f) if method == "torch" else f, L, spin=0, nside=nside, diff --git a/tests/test_torch_wrapper.py b/tests/test_torch_wrapper.py new file mode 100644 index 00000000..40d51811 --- /dev/null +++ b/tests/test_torch_wrapper.py @@ -0,0 +1,203 @@ +from importlib import reload +from unittest.mock import MagicMock, patch + +import jax +import numpy as np +import pytest +import torch +from jax.test_util import check_grads +from jax.tree_util import tree_all, tree_map + +from s2fft.utils import torch_wrapper +from s2fft.utils.signal_generator import complex_normal + +jax.config.update("jax_enable_x64", True) + + +def sum_abs_square(x: jax.Array) -> float: + return (abs(x) ** 2).sum() + + +def log_sum_exp(x: jax.Array) -> float: + max_x = x.max() + return max_x + jax.numpy.log(jax.numpy.exp(x - max_x).sum()) + + +def cubic(x: jax.Array) -> jax.Array: + return x**3 - 2 * x**2 + 3 * x - 1 + + +def conj(x: jax.Array) -> jax.Array: + return x.conj() + + +DTYPES = ["float32", "float64", "complex64", "complex128"] + +INPUT_SHAPES = [(), (1,), (2,), (3, 4)] + +PYTREE_STRUCTURES = [(), [(), ((1,), (2, 3))], {"a": [(1,), ()], "b": {"0": (1, 2)}}] + +JAX_SINGLE_ARG_FUNCTIONS = [sum_abs_square, log_sum_exp, 
+
+
+def generate_standard_normal(rng, shape, dtype):
+    if np.issubdtype(dtype, np.floating):
+        return rng.standard_normal(shape, dtype=dtype)
+    elif np.issubdtype(dtype, np.complexfloating):
+        return complex_normal(rng, shape, dtype=dtype)
+    else:
+        msg = f"dtype {dtype} must be a floating or complex floating data type"
+        raise ValueError(msg)
+
+
+def generate_pytree(rng, converter, dtype, structure):
+    if isinstance(structure, tuple):
+        if structure == () or all(isinstance(child, int) for child in structure):
+            return converter(generate_standard_normal(rng, structure, dtype))
+        else:
+            return tuple(
+                generate_pytree(rng, converter, dtype, child) for child in structure
+            )
+    elif isinstance(structure, list):
+        return [generate_pytree(rng, converter, dtype, child) for child in structure]
+    elif isinstance(structure, dict):
+        return {
+            key: generate_pytree(rng, converter, dtype, value)
+            for key, value in structure.items()
+        }
+    else:
+        raise TypeError(
+            f"pytree structure of type {type(structure)} not recognised"
+        )
+
+
+@pytest.mark.parametrize("input_shape", INPUT_SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_jax_array_to_torch_tensor(rng, input_shape, dtype):
+    x_jax = jax.numpy.asarray(generate_standard_normal(rng, input_shape, dtype=dtype))
+    x_torch = torch_wrapper.jax_array_to_torch_tensor(x_jax)
+    assert isinstance(x_torch, torch.Tensor)
+    assert x_torch.dtype == getattr(torch, dtype)
+    np.testing.assert_allclose(np.asarray(x_jax), np.asarray(x_torch))
+
+
+@pytest.mark.parametrize("input_shape", INPUT_SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_torch_tensor_to_jax_array(rng, input_shape, dtype):
+    x_torch = torch.from_numpy(generate_standard_normal(rng, input_shape, dtype=dtype))
+    x_jax = torch_wrapper.torch_tensor_to_jax_array(x_torch)
+    assert isinstance(x_jax, jax.Array)
+    assert x_jax.dtype == dtype
+    np.testing.assert_allclose(np.asarray(x_jax), np.asarray(x_torch))
+
+
+@pytest.mark.parametrize("pytree_structure", PYTREE_STRUCTURES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_tree_map_jax_array_to_torch_tensor(rng, pytree_structure, dtype):
+    jax_pytree = generate_pytree(rng, jax.numpy.asarray, dtype, pytree_structure)
+    torch_pytree = torch_wrapper.tree_map_jax_array_to_torch_tensor(jax_pytree)
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, jax.Array), jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == dtype, jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, torch.Tensor), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == getattr(torch, dtype), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(
+            lambda leaf_1, leaf_2: np.allclose(np.asarray(leaf_1), np.asarray(leaf_2)),
+            torch_pytree,
+            jax_pytree,
+        )
+    )
+
+
+@pytest.mark.parametrize("pytree_structure", PYTREE_STRUCTURES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_tree_map_torch_tensor_to_jax_array(rng, pytree_structure, dtype):
+    torch_pytree = generate_pytree(rng, torch.from_numpy, dtype, pytree_structure)
+    jax_pytree = torch_wrapper.tree_map_torch_tensor_to_jax_array(torch_pytree)
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, jax.Array), jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == dtype, jax_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: isinstance(leaf, torch.Tensor), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(lambda leaf: leaf.dtype == getattr(torch, dtype), torch_pytree),
+    )
+    assert tree_all(
+        tree_map(
+            lambda leaf_1, leaf_2: np.allclose(np.asarray(leaf_1), np.asarray(leaf_2)),
+            torch_pytree,
+            jax_pytree,
+        )
+    )
+
+
+@pytest.mark.parametrize("jax_function", JAX_SINGLE_ARG_FUNCTIONS)
+@pytest.mark.parametrize("input_shape", INPUT_SHAPES)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_wrap_as_torch_function_single_arg(rng, input_shape, dtype, jax_function):
+    x_numpy = generate_standard_normal(rng, input_shape, dtype=dtype)
+    x_jax = jax.numpy.asarray(x_numpy)
+    y_jax, vjp_jax_function = jax.vjp(jax_function, x_jax)
+    x_torch = torch.tensor(x_numpy, requires_grad=True)
+    torch_function = torch_wrapper.wrap_as_torch_function(jax_function)
+    y_torch = torch_function(x_torch)
+    assert isinstance(y_torch, torch.Tensor)
+    y_dtype = str(y_jax.dtype)
+    assert y_torch.dtype == getattr(torch, y_dtype)
+    np.testing.assert_allclose(np.asarray(y_jax), np.asarray(y_torch.detach()))
+    y_bar = generate_standard_normal(rng, y_jax.shape, dtype=y_dtype)
+    # JAX and PyTorch use different conventions for derivatives of complex functions
+    # (see https://github.com/jax-ml/jax/issues/4891) so we need to conjugate the
+    # inputs to and outputs from VJP to get equivalent behaviour to backward method on
+    # torch tensors
+    x_bar_jax = vjp_jax_function(y_bar.conj())[0].conj()
+    y_torch.backward(torch.from_numpy(y_bar))
+    assert x_torch.grad.dtype == getattr(torch, dtype)
+    np.testing.assert_allclose(
+        np.asarray(x_bar_jax), np.asarray(x_torch.grad.resolve_conj())
+    )
+
+
+# torch.autograd.gradcheck tolerances are calibrated for double precision so only do
+# checks with double precision floating point types
+@pytest.mark.parametrize("jax_function", JAX_SINGLE_ARG_FUNCTIONS)
+@pytest.mark.parametrize("input_shape", INPUT_SHAPES)
+@pytest.mark.parametrize("dtype", ["float64", "complex128"])
+def test_wrap_as_torch_function_single_arg_autograd_check(
+    rng, input_shape, dtype, jax_function
+):
+    x_numpy = generate_standard_normal(rng, input_shape, dtype=dtype)
+    x_jax = jax.numpy.asarray(x_numpy)
+    check_grads(jax_function, (x_jax,), order=1)
+    x_torch = torch.tensor(x_numpy, requires_grad=True)
+    torch_function = torch_wrapper.wrap_as_torch_function(jax_function)
+    torch.autograd.gradcheck(torch_function, x_torch)
+
+
+def test_check_pytorch_available():
+    try:
+        with patch.dict("sys.modules", torch=None):
+            reload(torch_wrapper)
+            with pytest.raises(RuntimeError, match="torch needs to be installed"):
+                torch_wrapper.check_torch_available()
+        with patch.dict("sys.modules", torch=MagicMock()):
+            reload(torch_wrapper)
+            # We should not get an exception here irrespective of whether torch is
+            # installed
+            torch_wrapper.check_torch_available()
+    finally:
+        # Ensure torch_wrapper is always reloaded with original state irrespective of
+        # test passing or failing
+        reload(torch_wrapper)
diff --git a/tests/test_wigner_precompute.py b/tests/test_wigner_precompute.py
index ae0a9a46..28ec66ba 100644
--- a/tests/test_wigner_precompute.py
+++ b/tests/test_wigner_precompute.py
@@ -1,3 +1,4 @@
+import jax
 import numpy as np
 import pytest
 import so3
@@ -8,6 +9,8 @@
 from s2fft.precompute_transforms.wigner import forward, inverse
 from s2fft.sampling import so3_samples as samples

+jax.config.update("jax_enable_x64", True)
+
 L_to_test = [6]
 N_to_test = [2, 6]
 nside_to_test = [4]
diff --git a/tests/test_wigner_transform.py b/tests/test_wigner_transform.py
index 26f20d71..4126e4bb 100644
--- a/tests/test_wigner_transform.py
+++ b/tests/test_wigner_transform.py
@@ -16,9 +16,16 @@
 N_to_test = [2]
 L_lower_to_test = [0, 2]
 sampling_to_test = ["mw", "mwss", "dh", "gl"]
-method_to_test = ["numpy", "jax"]
+method_to_test = ["numpy", "jax", "torch"]
 reality_to_test = [False, True]

+_generate_precomputes_functions = {
+    "jax": generate_precomputes_wigner_jax,
+    "numpy": generate_precomputes_wigner,
+    # torch method wraps jax so use jax to generate precomputes
+    "torch": generate_precomputes_wigner_jax,
+}
+

 @pytest.mark.parametrize("L", L_to_test)
 @pytest.mark.parametrize("N", N_to_test)
@@ -38,15 +45,9 @@ def test_inverse_wigner_transform(
 ):
     flmn = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
     f_check = base_wigner.inverse(flmn, L, N, L_lower, sampling, reality)
-
-    if method.lower() == "jax":
-        precomps = generate_precomputes_wigner_jax(
-            L, N, sampling, None, False, reality, L_lower
-        )
-    else:
-        precomps = generate_precomputes_wigner(
-            L, N, sampling, None, False, reality, L_lower
-        )
+    precomps = _generate_precomputes_functions[method](
+        L, N, sampling, None, False, reality, L_lower
+    )
     f = wigner.inverse(flmn, L, N, None, sampling, method, reality, precomps, L_lower)
     np.testing.assert_allclose(f, f_check, atol=1e-14)
@@ -69,15 +70,9 @@ def test_forward_wigner_transform(
 ):
     flmn = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
     f = base_wigner.inverse(flmn, L, N, L_lower, sampling, reality)
-
-    if method.lower() == "jax":
-        precomps = generate_precomputes_wigner_jax(
-            L, N, sampling, None, True, reality, L_lower
-        )
-    else:
-        precomps = generate_precomputes_wigner(
-            L, N, sampling, None, True, reality, L_lower
-        )
+    precomps = _generate_precomputes_functions[method](
+        L, N, sampling, None, True, reality, L_lower
+    )
     flmn_check = wigner.forward(
         f, L, N, None, sampling, method, reality, precomps, L_lower
     )
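Taken together, the tests above exercise the full round trip that `wrap_as_torch_function` provides. A minimal end-to-end sketch (assumptions: torch is installed, 64-bit JAX is enabled as in the tests, and the toy `scaled_norm_squared` function is illustrative rather than part of the library):

```python
import jax
import jax.numpy as jnp
import torch

from s2fft.utils.torch_wrapper import wrap_as_torch_function

jax.config.update("jax_enable_x64", True)


def scaled_norm_squared(x: jax.Array, scale: float = 2.0) -> jax.Array:
    # Only x carries a jax.Array annotation, so the wrapper's default argument
    # detection treats x (and not scale) as differentiable
    return scale * (jnp.abs(x) ** 2).sum()


torch_function = wrap_as_torch_function(scaled_norm_squared)

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, requires_grad=True)
y = torch_function(x)  # forward pass executes under JAX, returns a torch scalar
y.backward()  # backward pass routes through the stored jax.vjp
print(x.grad)  # tensor([ 4.,  8., 12.], dtype=torch.float64)
```

Note that the annotation-based detection relies on real type annotations: a module using `from __future__ import annotations` presents string annotations to `inspect.signature`, in which case the differentiable arguments should be named explicitly via `differentiable_argnames`.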