
Updating Healpix CUDA primitive #290

Open

ASKabalan wants to merge 43 commits into main from ASKabalan

Conversation

@ASKabalan (Collaborator) commented Mar 26, 2025

Adding a few updates

  • Updating to the newest custom call API (API 4) using FFI
  • Implementing a grad rule for the HEALPix CUDA FFT
  • Implementing a batching rule

A batching rule seems to be very important for two things: being able to jacrev/jacfwd, and because in most cases the size of a HEALPix map fits on a single GPU but sometimes we want to batch the spherical transform.

I will be doing that next
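As a rough illustration of the batched use case, here is a minimal sketch using the healpix_fft_cuda entry point discussed later in this thread (batch size and map parameters are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from s2fft.utils.healpix_ffts import healpix_fft_cuda

nside = 4
L = 2 * nside
npix = 12 * nside**2  # number of HEALPix pixels for this nside

# A stack of maps: each individual map fits on one GPU, but we want to
# apply the spherical transform FFT step to the whole batch at once.
f_batch = jnp.asarray(np.random.default_rng(0).standard_normal((8, npix)))

# The batching rule is what makes this vmap (and hence jacfwd/jacrev) work.
ftm_batch = jax.vmap(
    lambda f: healpix_fft_cuda(f=f, L=L, nside=nside, reality=False)
)(f_batch)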

@ASKabalan ASKabalan marked this pull request as draft March 26, 2025 16:25
@ASKabalan ASKabalan marked this pull request as ready for review March 28, 2025 16:08
@ASKabalan (Collaborator Author)

Hello @matt-graham @jasonmcewen @CosmoMatt

Just a quick PR to wrap up a few things

  1. Updated the binding API to the newest FFI
  2. Added a vmap implementation of the CUDA primitive
  3. Added a transpose rule which allows jacfwd and jacrev (and consequently grad as well)
  4. Added more tests: https://github.com/astro-informatics/s2fft/blob/ASKabalan/tests/test_healpix_ffts.py#L100
  5. Removed two files (kernel helpers) which are no longer needed with the FFI API (so maybe they should be removed from the license section)
  6. Constrained nanobind to ">=2.0,<2.6" because of a regression: [BUG]: Regression when using scikit build tools and nanobind wjakob/nanobind#982

And finally I added cudastreamhandler, which is used to split the XLA-provided stream for the vmap lowering (this header is my own work)
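Roughly, the new FFI-style binding looks like the following on the Python side. This is a sketch only: the extension module and capsule accessor names are hypothetical, and the output layout is an assumption.

import jax
import numpy as np

import _s2fft_ext  # hypothetical compiled extension module

# Register the C++/CUDA handler with XLA under the new FFI (custom call API 4).
jax.ffi.register_ffi_target(
    "healpix_fft_cuda", _s2fft_ext.healpix_fft(), platform="CUDA"
)

def healpix_fft(f, L, nside):
    # Assumed output layout: 4 * nside - 1 rings, 2 * L - 1 Fourier modes each.
    ftm_type = jax.ShapeDtypeStruct((4 * nside - 1, 2 * L - 1), np.complex128)
    return jax.ffi.ffi_call("healpix_fft_cuda", ftm_type)(
        f, L=np.int64(L), nside=np.int64(nside)
    )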

There is an issue with building pyssht; I am not sure that this is caused by my changes.

I will check the failing workflows when I get the chance, but in the meantime a review would be appreciated

@matt-graham (Collaborator) left a comment


Hi @ASKabalan, sorry for the delay in getting back to you.

This all sounds great - thanks for picking up #237 in particular and for the updates to use the newer FFI interface.

With regards to the failing workflows - this was probably due to #292 which was fixed in #293. If you merge in latest main here that should hopefully resolve the upstream dependency build problems that were causing the test workflows to fail.

I've added some initial review comments below. Will have a closer look next week and try testing this out, but I don't have access to a GPU machine at the moment.

Comment on lines 150 to 151
flm_hp = samples.flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)

I think we could use s2fft.inverse(flm, L=L, reality=False, method="jax", sampling="healpix") here instead of going via healpy? Rationale being that I have a slight preference for minimising the number of additional tests that depend on healpy, as we are no longer requiring it as a direct dependency of the package, and in the long run it might be possible to also remove it as a test dependency.
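Concretely, the two lines above would become something like the following (adding the nside keyword used with healpix sampling elsewhere in these tests):

# Generate the test signal with s2fft itself rather than via healpy:
f = s2fft.inverse(
    flm, L=L, nside=nside, reality=False, method="jax", sampling="healpix"
)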

ASKabalan and others added 2 commits April 16, 2025 11:17
Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
@matt-graham (Collaborator)

I've tried building, installing and running this on a system with CUDA 12.6 + an NVIDIA A100, and running the HEALPix FFT tests with

pytest tests/test_healpix_ffts.py

the tests consistently hang when trying to run the first test_healpix_fft_cuda instance.

Running just the IFFT tests with

pytest tests/test_healpix_ffts.py::test_healpix_ifft_cuda

the tests for both sets of test parameters pass.

Trying to dig into this a bit, running the following locally

import healpy
import jax
import s2fft
import numpy

jax.config.update("jax_enable_x64", True)

seed = 20250416
nside = 4
L = 2 * nside
reality = False

rng = numpy.random.default_rng(seed)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm, L)
f = healpy.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
flm_cuda = s2fft.utils.healpix_ffts.healpix_fft_cuda(f=f, L=L, nside=nside, reality=reality).block_until_ready()

raises an error

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: : CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered

so it looks like there is some memory addressing issue somewhere in the healpix_fft_cuda implementation?

@ASKabalan (Collaborator Author)

Thank you

I was able to reproduce with 12.4.1 but not locally with 12.4

I will take a look


codecov bot commented Jun 19, 2025

Codecov Report

❌ Patch coverage is 75.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 96.14%. Comparing base (b239a11) to head (3bcb69a).

Files with missing lines        Patch %   Lines
s2fft/utils/healpix_ffts.py     66.66%    3 Missing ⚠️
s2fft/utils/jax_primitive.py    75.00%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #290      +/-   ##
==========================================
- Coverage   96.50%   96.14%   -0.36%     
==========================================
  Files          32       32              
  Lines        3434     3453      +19     
==========================================
+ Hits         3314     3320       +6     
- Misses        120      133      +13     

☔ View full report in Codecov by Sentry.

@ASKabalan (Collaborator Author)

Hey @matt-graham,
I am picking up where I left off.
It seems that there is an error when building with Python 3.8.
It doesn't seem to be coming from my code.
It seems to be caused by a compile error when compiling sht.

I would suggest dropping Python 3.8 from the test suite since JAX no longer supports it anyway

@matt-graham (Collaborator)

Hey @matt-graham, I am picking up where I left off. It seems that there is an error when building with Python 3.8. It doesn't seem to be coming from my code. It seems to be caused by a compile error when compiling sht.

Hi @ASKabalan. Do you mean so3 rather than (py)ssht? From a quick look at the logs of the failing Actions workflow job on Python 3.8 / ubuntu-latest, it appears to be an error with building so3 (ERROR: Failed building wheel for so3). If so, this is likely the same issue as described in #308. I've opened a PR to try to fix this upstream in so3 (astro-informatics/so3#31).

I would suggest dropping Python 3.8 from the test suite since JAX no longer supports it anyway

Yes, agreed we should drop Python 3.8 from the test matrix - we have an open pull request #305 to update to only supporting Python 3.11+, but this is partially blocked by #212, as the tests currently exit with fatal errors when running on macOS / Python 3.9+ due to an incompatibility between the OpenMP runtimes that the macOS wheels for healpy and PyTorch are built against (healpy/healpy#1012)

Add comprehensive documentation and fix dependency issues for CUDA FFT integration.

This commit introduces extensive docstrings and inline comments across the C++ and Python codebase, particularly for the CUDA FFT implementation. It also addresses a dependency issue to ensure proper installation and functionality.

Key changes include:
- No more CUDA malloc: all memory is allocated in Python by XLA
- Added detailed docstrings to C++ header files
- Enhanced inline comments in C++ source files to explain complex logic and algorithms
- Relaxed the JAX version dependency, resolving installation issues
- Refined docstrings and comments in Python files for clarity and consistency
- Cleaned up debug print statements
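A sketch of that memory model, with hypothetical names and shapes: the scratch space becomes an extra FFI result, so XLA allocates it and the C++ side never needs to call cudaMalloc itself.

import jax
import numpy as np

def healpix_fft_with_workspace(f, L, nside):
    # Assumed output layout; the real primitive derives it from the ring structure.
    ftm_type = jax.ShapeDtypeStruct((4 * nside - 1, 2 * L - 1), np.complex128)
    # Declaring the workspace as a second result makes XLA own the allocation.
    workspace_type = jax.ShapeDtypeStruct(f.shape, np.complex128)
    ftm, _workspace = jax.ffi.ffi_call(
        "healpix_fft_cuda", (ftm_type, workspace_type)
    )(f, L=np.int64(L), nside=np.int64(nside))
    return ftm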
matt-graham and others added 9 commits November 11, 2025 03:14
* Update Python version used in docs workflow

* Trigger docs workflow on pull-requests

* Deploy only on push to main
Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 3.0.0 to 3.0.1.
- [Release notes](https://github.com/pypa/cibuildwheel/releases)
- [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md)
- [Commits](pypa/cibuildwheel@v3.0.0...v3.0.1)

---
updated-dependencies:
- dependency-name: pypa/cibuildwheel
  dependency-version: 3.0.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 3.0.1 to 3.1.3.
- [Release notes](https://github.com/pypa/cibuildwheel/releases)
- [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md)
- [Commits](pypa/cibuildwheel@v3.0.1...v3.1.3)

---
updated-dependencies:
- dependency-name: pypa/cibuildwheel
  dependency-version: 3.1.3
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* Update custom_ops.py

Small compatibility change which disables jitting on the s2fft side and in turn enables higher-level jitting in s2ai.

* Update custom_ops.py

Removed commented lines for linting purposes

* Removing now unused imports

---------

Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
Bumps [actions/checkout](https://github.com/actions/checkout) from 4.2.2 to 5.0.0.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](actions/checkout@v4.2.2...v5.0.0)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: 5.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4 to 5.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](actions/download-artifact@v4...v5)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 3.1.3 to 3.1.4.
- [Release notes](https://github.com/pypa/cibuildwheel/releases)
- [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md)
- [Commits](pypa/cibuildwheel@v3.1.3...v3.1.4)

---
updated-dependencies:
- dependency-name: pypa/cibuildwheel
  dependency-version: 3.1.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
@ASKabalan (Collaborator Author)

@matt-graham
It took some time to get back to this,
but I think that the race condition is now fixed

@ASKabalan (Collaborator Author)

@matt-graham I think this is now good to go.
I have issues with macOS, but I am not sure that I am the root cause.

Otherwise, I made a much more robust out-of-place shift in the CUDA kernels by fusing two of the kernels I had earlier.

In short, there is no longer any need for synchronization.
This can be merged IMO (it also works with the latest JAX)

@matt-graham (Collaborator) commented Dec 18, 2025

@matt-graham I think this is now good to go. I have issues with macOS, but I am not sure that I am the root cause.

Otherwise, I made a much more robust out-of-place shift in the CUDA kernels by fusing two of the kernels I had earlier.

In short, there is no longer any need for synchronization. This can be merged IMO (it also works with the latest JAX)

Hi @ASKabalan thanks for your updates, and apologies for being very slow in getting to reviewing this.

Generally the changes look good, but there still seem to be some memory-related issues occurring for me in testing. Specifically:

When running pytest tests/test_healpix_ffts.py with the package built from the tip of this PR branch (3bcb69a), I am occasionally but not consistently getting segmentation faults:

tests/test_healpix_ffts.py::test_healpix_ifft_cuda_transforms[8966433580120847635-8] Fatal Python error: Segmentation fault

Thread 0x0000ffff975052e0 (most recent call first):
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2165 in default_process_primitive
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2117 in process_primitive
  File Segmentation fault (core dumped)
tests/test_healpix_ffts.py::test_healpix_fft_cuda_transforms[8966433580120847635-16] Fatal Python error: Segmentation fault

Thread 0x0000ffff9a0152e0 (most recent call first):
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/compiler.py", line 362 in backend_compile_and_load
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/profiler.py", line 359 in wrapper
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/compiler.py", line 746 in _compile_and_write_cache
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/compiler.py", line 478 in compile_or_get_cached
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 2854 in _cached_compilation
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 3073 in from_hlo
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 2527 in compile
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 1620 in _pjit_call_impl_python
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 146 in _python_pjit_helper
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/pjit.py", line 264 in cache_miss
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 195 in reraise_with_filtered_traceback
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/numpy/ufunc_api.py", line 182 in __call__
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py", line 608 in deferring_binary_op
  File "/home/ccaemgr/projects/s2fft/tests/test_healpix_ffts.py", line 118 in test_healpix_fft_cuda_transforms
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/python.py", line 166 in pytest_pyfunc_call
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_callers.py", line 121 in _multicall
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_hooks.py", line 512 in __call__
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/python.py", line 1720 in runtest
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/runner.py", line 179 in pytest_runtest_call
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_callers.py", line 121 in _multicall
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_hooks.py", line 512 in __call__
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/runner.py", line 245 in <lambda>
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/runner.py", line 353 in from_call
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/runner.py", line 244 in call_and_report
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/runner.py", line 137 in runtestprotocol
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/runner.py", line 118 in pytest_runtest_protocol
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_callers.py", line 121 in _multicall
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_hooks.py", line 512 in __call__
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/main.py", line 396 in pytest_runtestloop
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_callers.py", line 121 in _multicall
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_hooks.py", line 512 in __call__
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/main.py", line 372 in _main
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/main.py", line 318 in wrap_session
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/main.py", line 365 in pytest_cmdline_main
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_callers.py", line 121 in _multicall
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/pluggy/_hooks.py", line 512 in __call__
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/config/__init__.py", line 199 in main
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/_pytest/config/__init__.py", line 223 in console_main
  File "/home/ccaemgr/projects/s2fft/test-venv/bin/pytest", line 8 in <module>

Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, astropy.io.fits._utils, astropy.io.fits.hdu.compressed._compression, healpy._healpy_sph_transform_lib, healpy._sphtools, healpy._pixelfunc, healpy._query_disc, erfa.ufunc, astropy.time._parse_times, astropy.table._column_mixins, astropy.table._np_utils, yaml._yaml, astropy.io.ascii.cparser, astropy.utils.xml._iterparser, astropy.io.votable.tablewriter, healpy._masktools, healpy._hotspots, healpy._line_integral_convolution, jaxlib.cpu_feature_guard, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._mt19937, numpy.random._generator, numpy.random._philox, numpy.random._sfc64, numpy.random.mtrand, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special (total: 40)
Segmentation fault (core dumped)
tests/test_healpix_ffts.py::test_healpix_ifft_cuda_transforms[8966433580120847635-8] Fatal Python error: Segmentation fault

Thread 0x0000ffff8d4c52e0 (most recent call first):
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/lax/slicing.py", line 1823 in _is_sorted
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/lax/slicing.py", line 2533 in _scatter_shape_rule
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/lax/utils.py", line 111 in call_shape_dtype_sharding_rule
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/lax/utils.py", line 168 in standard_abstract_eval
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/core.py", line 702 in abstract_eval_
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 1946 in _cached_abstract_eval
  File "/home/ccaemgr/projects/s2fft/test-venv/lib/python3.13/site-packages/jax/_src/util.py", line 444 in cache_miss
  File Segmentation fault (core dumped)

This does not appear to happen consistently in the same test, but it does seem to only happen in the transforms-labelled tests.

From some testing I am also getting odd behaviour with the forward transforms. Specifically, something in the code path seems to be mutating the input (signal) array f in a way that sidesteps some of JAX's internal caching mechanisms, so that the value reported by the array's __repr__ method does not match the actual array data (following the code path, __repr__ accesses the _value property, which itself uses a cached attribute _npy_value that appears to be getting out of sync with the actual array data in the _arrays property).

For example the code snippet

import jax
jax.config.update("jax_enable_x64", True)
import s2fft
import numpy as np

seed = 1234
L = 3
reality = True

rng = np.random.default_rng(seed)
kwargs = {"L": L, "reality": reality, "nside": L // 2, "sampling": "healpix", "method": "jax_cuda"}

flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
f = s2fft.inverse(flm, **kwargs)
f_copy = f.copy()

print(f"Before forward call {np.allclose(f, f_copy) = }")

print(f"Before forward call {f - f_copy = }")

flm_recov = s2fft.forward(f, **kwargs)

print(f"After first forward call {np.allclose(f, f_copy) = }")
print(f"After first forward call {f - f_copy = }")

flm_recov_2 = s2fft.forward(f, **kwargs)

print(f"After second forward call {np.allclose(f, f_copy) = }")
print(f"After second forward call {f - f_copy = }")

print(f"{np.allclose(flm_recov, flm_recov_2) = }"

outputs

Before forward call np.allclose(f, f_copy) = True
Before forward call f - f_copy = Array([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j,
       0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], dtype=complex128)
After first forward call np.allclose(f, f_copy) = True
After first forward call f - f_copy = Array([-1.22631372+0.j        , -0.53001587-0.99845605j,
        1.94917465+0.j        ,  0.46844017+0.99845605j,
       -4.2033667 +0.j        ,  3.74754851+2.04369928j,
        7.33217726+0.j        ,  1.70384924-2.04369928j,
       -1.26285115+0.j        ,  1.29868071+2.93037909j,
        1.90219375+0.j        , -1.63169838-2.93037909j],      dtype=complex128)
After second forward call np.allclose(f, f_copy) = True
After second forward call f - f_copy = Array([ -0.56502848+0.j,  -5.70241634+0.j,   2.73361129+0.j,
        -0.7101361 +0.j,   4.37684161+0.j,  -3.7005969 +0.j,
         5.00959007+0.j, -13.91909328+0.j,  -0.95652623+0.j,
         3.99439399+0.j,   2.87455401+0.j, -10.65750146+0.j],      dtype=complex128)
np.allclose(flm_recov, flm_recov_2) = False

Notice that while np.allclose(f, f_copy) reports True after the first s2fft.forward call, computing the difference f - f_copy shows the underlying array data differs. I suspect this is because the call to np.allclose is using the cached _npy_value attributes, which are out of sync with the data. We also get a difference in the output of the two successive s2fft.forward calls with the same input array f, which I think reflects that the call to s2fft.forward is mutating its input argument, even though this doesn't show up in the string representation or when testing equality against a copy with np.allclose.

I am going to continue trying to debug where this memory issue is arising and see if it relates to the segfaults, but thought I'd give a (long overdue) update first just so you know this is on my radar!

@matt-graham (Collaborator)

Looking at the code in a bit more detail, I think at least part of what is causing this issue is that healpix_forward in lib/src/extensions.cc performs the FFT operations in place on the passed input buffer. Concentrating on the non-batched case for simplicity:

s2fft/lib/src/extensions.cc

Lines 127 to 142 in 3bcb69a

fft_complex_type* data_c = reinterpret_cast<fft_complex_type*>(input.typed_data());
fft_complex_type* out_c = reinterpret_cast<fft_complex_type*>(output->typed_data());
fft_complex_type* workspace_c = reinterpret_cast<fft_complex_type*>(workspace->typed_data());
// Step 2l: Get or create an s2fftExec instance from the PlanCache.
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
PlanCache::GetInstance().GetS2FFTExec(descriptor, executor);
// Step 2m: Launch the forward transform.
executor->Forward(descriptor, stream, data_c, workspace_c);
// Step 2n: Launch spectral extension kernel with shift and normalization.
int kernel_norm = (descriptor.norm == s2fftKernels::fft_norm::FORWARD) ? 0
                  : (descriptor.norm == s2fftKernels::fft_norm::ORTHO) ? 1
                                                                       : 2;
s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside,
                                        descriptor.harmonic_band_limit, descriptor.shift, kernel_norm,
                                        stream);

data_c is a reinterpreted cast view of the input buffer data and is passed to executor->Forward along with the workspace buffer view workspace_c. The logic of the implementation of Forward seems to be to perform the cuFFT operations in place on slices of the data_c argument:

s2fft/lib/src/s2fft.cu

Lines 130 to 131 in 3bcb69a

CUFFT_CALL(
    cufftXtExec(m_polar_plans[i], data + upper_ring_offset, data + upper_ring_offset, DIRECTION));

with pointers to offsets within the data_c array passed as both input and output arguments to cufftXtExec. data_c, after the call to Forward, is then passed directly to the launch_spectral_extension kernel, which writes output to a separate out_c output buffer view.

Supporting this hypothesis, if we inspect the values in the mutated underlying _arrays property of an array passed to s2fft.forward with method="jax_cuda", these do seem to correspond to FFTs of slices of the original input array values corresponding to the HEALPix rings:

import jax
jax.config.update("jax_enable_x64", True)
import s2fft
import numpy as np

from s2fft.sampling import s2_samples

seed = 1234
L = 4
sampling = "healpix"
method = "jax_cuda"
nside = L // 2
reality = True

rng = np.random.default_rng(seed)
kwargs = {"L": L, "reality": reality, "nside": nside, "sampling": sampling}

flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
f = s2fft.inverse(flm, **kwargs, method="jax")

f_copy = f.copy()

# Required to ensure a cached _npy_value is recorded before mutation
print(f"Before forward call: {np.all(f_copy == f._value) = }")
print(f"Before forward call: {np.all(f_copy == f._arrays[0]) = }")

s2fft.forward(f, **kwargs, method="jax_cuda")

print(f"After forward call: {np.all(f_copy == f._value) = }")
print(f"After forward call: {np.all(f_copy == f._arrays[0]) = }")

index = 0
for t in range(s2_samples.ntheta(L, sampling, nside)):
    nphi = s2_samples.nphi_ring(t, nside)
    fft_slice_matches = np.allclose(
        jax.numpy.fft.fft(f._value[index : index + nphi], norm="backward"),
        f._arrays[0][index : index + nphi]
    )
    print(f"Ring {t}, {fft_slice_matches = }")
    index += nphi

outputs

Before forward call: np.all(f_copy == f._value) = Array(True, dtype=bool)
Before forward call: np.all(f_copy == f._arrays[0]) = Array(True, dtype=bool)
After forward call: np.all(f_copy == f._value) = Array(True, dtype=bool)
After forward call: np.all(f_copy == f._arrays[0]) = Array(False, dtype=bool)
Ring 0, fft_slice_matches = True
Ring 1, fft_slice_matches = True
Ring 2, fft_slice_matches = True
Ring 3, fft_slice_matches = True
Ring 4, fft_slice_matches = True
Ring 5, fft_slice_matches = True
Ring 6, fft_slice_matches = True

There seems to be something more going on than just this, however, as if we change the lines

flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
f = s2fft.inverse(flm, **kwargs, method="jax")

to

f = jax.numpy.asarray(rng.standard_normal(s2_samples.f_shape(L, sampling, nside)))

so removing the s2fft.inverse call (with method="jax"), we instead get output

Before forward call: np.all(f_copy == f._value) = Array(True, dtype=bool)
Before forward call: np.all(f_copy == f._arrays[0]) = Array(True, dtype=bool)
After forward call: np.all(f_copy == f._value) = Array(True, dtype=bool)
After forward call: np.all(f_copy == f._arrays[0]) = Array(True, dtype=bool)
Ring 0, fft_slice_matches = False
Ring 1, fft_slice_matches = False
Ring 2, fft_slice_matches = False
Ring 3, fft_slice_matches = False
Ring 4, fft_slice_matches = False
Ring 5, fft_slice_matches = False
Ring 6, fft_slice_matches = False

suggesting in this case s2fft.forward does not mutate the underlying data in f 😕

@ASKabalan (Collaborator Author)

I just deactivated the "forward" part like so:

// Step 2j: Non-batched case.
// Step 2k: Get device pointers for data, output, and workspace.
fft_complex_type* data_c = reinterpret_cast<fft_complex_type*>(input.typed_data());
fft_complex_type* out_c = reinterpret_cast<fft_complex_type*>(output->typed_data());
fft_complex_type* workspace_c = reinterpret_cast<fft_complex_type*>(workspace->typed_data());

// Step 2l: Get or create an s2fftExec instance from the PlanCache.
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
PlanCache::GetInstance().GetS2FFTExec(descriptor, executor);
// Step 2m: Launch the forward transform.
std::cout << "Commenting forward call" << std::endl;
// executor->Forward(descriptor, stream, data_c, workspace_c);
// Step 2n: Launch spectral extension kernel with shift and normalization.
int kernel_norm = (descriptor.norm == s2fftKernels::fft_norm::FORWARD) ? 0
                : (descriptor.norm == s2fftKernels::fft_norm::ORTHO) ? 1
                                                                : 2;
s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside,
                                descriptor.harmonic_band_limit, descriptor.shift, kernel_norm,
                                stream);
return ffi::Error::Success();

And now the input is not being corrupted, so your intuition is correct.

So now the issue is that I can obviously patch this, but it requires me to copy the input map (or make a scratch copy). What I did for the 3D FFTs was to 'alias' the input with the output, which tells XLA to 'try' to reuse the input buffer.

This is not possible here since the output does not have the same size as the input.

I have to study the FFI literature more, but the simplest way to do this is to use the scratch allocator to make temporary memory,

OR

have an option where the user can donate the buffer or not.

I will look in the FFI docs to see if there is something I can do for this without duplicating memory.

In any case, going from complex to (un)folded has to be done out of place if we don't want to use complicated synchronization in CUDA.

@matt-graham (Collaborator)

I just deactivated the "forward" part like so:

...

And now the input is not being corrupted, so your intuition is correct.

Thanks for checking this @ASKabalan.

So now the issue is that I can obviously patch this, but it requires me to copy the input map (or make a scratch copy). What I did for the 3D FFTs was to 'alias' the input with the output, which tells XLA to 'try' to reuse the input buffer.

This is not possible here since the output does not have the same size as the input.

I have to study the FFI literature more, but the simplest way to do this is to use the scratch allocator to make temporary memory,

OR

have an option where the user can donate the buffer or not.

For the latter option: if we had an additional output corresponding to the intermediate per-ring FFTs applied to the input for the underlying _healpix_fft_cuda_primitive, analogous to the current temporary workspace output, then I guess we could provide a wrapped 'in-place' version of this primitive, jitted with donate_argnums to indicate that the input buffer should be donated to the buffer for the intermediate FFT output, and then dispatch to the relevant function (jitted with or without donate_argnums) in the exposed healpix_fft_cuda function, based on an inplace argument or similar passed by the user? Is this similar to what you are suggesting here? If so this seems like a good option to me.
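A minimal sketch of that dispatch idea. The names here are illustrative, and the underlying primitive is assumed to return (ftm, intermediate), with the second output receiving the per-ring FFTs that are currently written into the input buffer:

import jax

def _healpix_fft_impl(f, L, nside):
    # Hypothetical wrapper around the underlying primitive with the extra
    # output discussed above.
    return _healpix_fft_cuda_primitive(f, L=L, nside=nside)

# In-place variant: donating f lets XLA reuse its buffer for the
# intermediate per-ring FFT output.
_fft_inplace = jax.jit(_healpix_fft_impl, static_argnums=(1, 2), donate_argnums=0)
_fft_copying = jax.jit(_healpix_fft_impl, static_argnums=(1, 2))

def healpix_fft_cuda(f, L, nside, inplace=False):
    fft = _fft_inplace if inplace else _fft_copying
    ftm, _intermediate = fft(f, L, nside)
    return ftm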

multiple_results=False,
multiple_results=True, # Indicates that the primitive returns multiple outputs.
abstract_evaluation=_healpix_fft_cuda_abstract,
lowering_per_platform={None: _healpix_fft_cuda_lowering},

Suggested change
lowering_per_platform={None: _healpix_fft_cuda_lowering},
lowering_per_platform={"cuda": _healpix_fft_cuda_lowering},

If we changed this, would it give us an earlier / more informative error message if a user tries to lower on a non-CUDA device?

Comment on lines +117 to +130
MSE = jnp.mean(
    (jax.vmap(healpix_jax)(f_stacked) - jax.vmap(healpix_cuda)(f_stacked)) ** 2
)
assert MSE < 1e-14
# test jacfwd
MSE = jnp.mean(
    (jax.jacfwd(healpix_jax)(f.real) - jax.jacfwd(healpix_cuda)(f.real)) ** 2
)
assert MSE < 1e-14
# test jacrev
MSE = jnp.mean(
    (jax.jacrev(healpix_jax)(f.real) - jax.jacrev(healpix_cuda)(f.real)) ** 2
)
assert MSE < 1e-14

Ideally these should use assert_allclose, similar to other tests, as this gives a more informative error message with a summary of differences if the test fails, and allows checking for differences at a per-element level rather than via an overall norm.
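For example, the vmap check above could become something like the following (tolerances here are illustrative):

from numpy.testing import assert_allclose

# On failure this reports which elements differ and by how much,
# instead of a single aggregate MSE value.
assert_allclose(
    jax.vmap(healpix_cuda)(f_stacked),
    jax.vmap(healpix_jax)(f_stacked),
    atol=1e-7,
    rtol=1e-7,
)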

Comment on lines 156 to 177
# Test VMAP
MSE = jnp.mean(
    (
        jax.vmap(healpix_inv_jax)(ftm_stacked)
        - jax.vmap(healpix_inv_cuda)(ftm_stacked)
    )
    ** 2
)
assert MSE < 1e-14
# test jacfwd
MSE = jnp.mean(
    (jax.jacfwd(healpix_inv_jax)(ftm.real) - jax.jacfwd(healpix_inv_cuda)(ftm.real))
    ** 2
)
assert MSE < 1e-14

# test jacrev
MSE = jnp.mean(
    (jax.jacrev(healpix_inv_jax)(ftm.real) - jax.jacrev(healpix_inv_cuda)(ftm.real))
    ** 2
)
assert MSE < 1e-14

Similar to the comment above, it would be better if these checks used assert_allclose with appropriate tolerances to give more informative error messages on failures.

@ASKabalan (Collaborator Author)

For the latter option: if we had an additional output corresponding to the intermediate per-ring FFTs applied to the input for the underlying _healpix_fft_cuda_primitive, analogous to the current temporary workspace output, then I guess we could provide a wrapped 'in-place' version of this primitive, jitted with donate_argnums to indicate that the input buffer should be donated to the buffer for the intermediate FFT output, and then dispatch to the relevant function (jitted with or without donate_argnums) in the exposed healpix_fft_cuda function, based on an inplace argument or similar passed by the user? Is this similar to what you are suggesting here? If so this seems like a good option to me.

Yeah, this is very smart indeed.
Actually, in this case I don't need to use donated buffers, because the buffer will automatically be donated if the user doesn't use it afterwards.
In XLA I can alias output and input to instruct XLA to do this.

I will implement this when I get the chance
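A sketch of that aliasing variant, building on the workspace idea above (hypothetical names; the aliasing is legal because the intermediate buffer is assumed to have the same shape and dtype as the input):

import jax
import numpy as np

def healpix_fft_aliased(f, L, nside):
    ftm_type = jax.ShapeDtypeStruct((4 * nside - 1, 2 * L - 1), np.complex128)
    intermediate_type = jax.ShapeDtypeStruct(f.shape, np.complex128)
    # input_output_aliases={0: 1} tells XLA it may reuse the buffer of
    # input 0 (f) for output 1 (the intermediate per-ring FFTs), so no
    # extra copy of the input map is needed.
    ftm, _intermediate = jax.ffi.ffi_call(
        "healpix_fft_cuda",
        (ftm_type, intermediate_type),
        input_output_aliases={0: 1},
    )(f, L=np.int64(L), nside=np.int64(nside))
    return ftm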


Development

Successfully merging this pull request may close these issues:

Check autodiff and batching support for healpix_fft_cuda primitive and add if needed