Conversation
Hello @matt-graham @jasonmcewen @CosmoMatt

Just a quick PR to wrap up a few things:

1. Updated the binding API to the newest [FFI](https://docs.jax.dev/en/latest/ffi.html).
2. Added a `vmap` implementation of the CUDA primitive.
3. Added a transpose rule which allows `jacfwd` and `jacrev` (and consequently `grad` as well).
4. Added more tests: https://github.com/astro-informatics/s2fft/blob/ASKabalan/tests/test_healpix_ffts.py#L100
5. Removed two files which are no longer needed with the FFI API ([kernel helpers](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h)), so maybe they should be removed from the license section.
6. Constrained nanobind to `nanobind >=2.0,<2.6` because of a regression ([[BUG]: Regression when using scikit build tools and nanobind wjakob/nanobind#982](https://github.com/wjakob/nanobind/issues/982)).

And finally, I added `cudastreamhandler`, which is used to split the XLA-provided stream for the `vmap` lowering (this header is my own work).

There is an issue with building pyssht; I am not sure this is my fault. I will check the failing workflows when I get the chance, but in the meantime a review is appreciated.
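For context on point 1, a minimal sketch of the new-style `jax.ffi` binding pattern (the target name, registration call, and shapes below are hypothetical illustrations, not the actual s2fft bindings):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical: the compiled extension exposes an XLA FFI handler capsule,
# registered once at import time, e.g.
# jax.ffi.register_ffi_target(
#     "healpix_fft_cuda", _cuda_ext.healpix_fft(), platform="CUDA"
# )

def healpix_fft(f, L, nside):
    # With the FFI API the output type is declared up front, and scalar
    # attributes are forwarded to the C++ handler as keyword arguments.
    out_type = jax.ShapeDtypeStruct(f.shape, jnp.complex128)
    return jax.ffi.ffi_call("healpix_fft_cuda", out_type)(
        f, L=np.int64(L), nside=np.int64(nside)
    )
```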
matt-graham left a comment
Hi @ASKabalan, sorry for the delay in getting back to you.
This all sounds great - thanks for picking up #237 in particular and for the updates to use the newer FFI interface.
With regards to the failing workflows - this was probably due to #292 which was fixed in #293. If you merge in latest main here that should hopefully resolve the upstream dependency build problems that were causing the test workflows to fail.
I've added some initial review comments below. I will have a closer look next week and try testing this out, but I don't have access to a GPU machine at the moment.
tests/test_healpix_ffts.py (outdated)

```python
flm_hp = samples.flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
```
I think we could use `s2fft.inverse(flm, L=L, reality=False, method="jax", sampling="healpix")` here instead of going via healpy. The rationale is that I have a slight preference for minimising the number of additional tests that depend on healpy, as we are no longer requiring it as a direct dependency for the package, and in the long run it might be possible to also remove it as a test dependency.
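A sketch of what that replacement could look like in the test (variable names follow the surrounding test code):

```python
# Generate the HEALPix-sampled signal with s2fft itself rather than via healpy.
f = s2fft.inverse(flm, L=L, reality=False, method="jax", sampling="healpix")
```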
I've tried building, installing and running this on a system with CUDA 12.6 + an NVIDIA A100. When running the HEALPix FFT tests, the tests consistently hang when trying to run the first test. Running just the IFFT tests, the tests for both sets of test parameters pass. Trying to dig into this a bit, running the following locally

```python
import healpy
import jax
import s2fft
import numpy

jax.config.update("jax_enable_x64", True)

seed = 20250416
nside = 4
L = 2 * nside
reality = False
rng = numpy.random.default_rng(seed)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm, L)
f = healpy.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
flm_cuda = s2fft.utils.healpix_ffts.healpix_fft_cuda(
    f=f, L=L, nside=nside, reality=reality
).block_until_ready()
```

raises an error, so it looks like there is some memory addressing issue somewhere in the CUDA implementation.
Thank you, I was able to reproduce with 12.4.1 but not locally with 12.4. I will take a look.
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #290      +/-   ##
==========================================
- Coverage   96.50%   96.14%   -0.36%
==========================================
  Files          32       32
  Lines        3434     3453      +19
==========================================
+ Hits         3314     3320       +6
- Misses        120      133      +13
```

☔ View full report in Codecov by Sentry.
@matt-graham Hey, I would suggest dropping Python 3.8 from the test suite since JAX no longer supports it anyway.
Hi @ASKabalan. Do you mean …?

Yes, agreed we should drop Python 3.8 from the test matrix. We have an open pull request #305 to update to only supporting Python 3.11+, but this is partially blocked by #212, as the tests currently exit with fatal errors when running on MacOS / Python 3.9+ due to an incompatibility between the OpenMP runtimes that the MacOS wheels of some dependencies are built against.
Add comprehensive documentation and fix dependency issues for CUDA FFT integration.

This commit introduces extensive docstrings and inline comments across the C++ and Python codebase, particularly for the CUDA FFT implementation. It also addresses a dependency issue to ensure proper installation and functionality.

Key changes include:
- No more CUDA malloc: all memory is allocated in Python by XLA.
- Added detailed docstrings to C++ header files.
- Enhanced inline comments in C++ source files to explain complex logic and algorithms.
- Relaxed the JAX version dependency, resolving installation issues.
- Refined docstrings and comments in Python files for clarity and consistency.
- Cleaned up debug print statements.
* Update Python version used in docs workflow
* Trigger docs workflow on pull-requests
* Deploy only on push to main
Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 3.0.0 to 3.0.1.
- [Release notes](https://github.com/pypa/cibuildwheel/releases)
- [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md)
- [Commits](pypa/cibuildwheel@v3.0.0...v3.0.1)

Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 3.0.1 to 3.1.3.
- [Release notes](https://github.com/pypa/cibuildwheel/releases)
- [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md)
- [Commits](pypa/cibuildwheel@v3.0.1...v3.1.3)
* Update custom_ops.py: small compatibility change which disables jitting on the s2fft side, in turn enabling higher-level jitting in s2ai.
* Update custom_ops.py: removed commented lines for linting purposes.
* Removed now-unused imports.

Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
Bumps [actions/checkout](https://github.com/actions/checkout) from 4.2.2 to 5.0.0.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](actions/checkout@v4.2.2...v5.0.0)

Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4 to 5.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](actions/download-artifact@v4...v5)

Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 3.1.3 to 3.1.4.
- [Release notes](https://github.com/pypa/cibuildwheel/releases)
- [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md)
- [Commits](pypa/cibuildwheel@v3.1.3...v3.1.4)
@matt-graham
@matt-graham I think this is now good to go. Other than that, I made a much more robust out-of-place shift in the CUDA kernels by fusing two of the kernels I had earlier; in short, there is no longer a need for synchronization.
Hi @ASKabalan, thanks for your updates, and apologies for being very slow in getting to reviewing this. Generally the changes look good, but there still seem to be some memory-related issues occurring for me in testing.

Specifically: when running the tests I get segmentation faults. This does not appear to consistently happen in the same test, but it does seem to only happen in certain tests. From some testing I am also getting odd behaviour with the forward transforms; specifically, something in the codepath seems to be mutating the input (signal) array. For example, the code snippet

```python
import jax

jax.config.update("jax_enable_x64", True)

import s2fft
import numpy as np

seed = 1234
L = 3
reality = True
rng = np.random.default_rng(seed)
kwargs = {"L": L, "reality": reality, "nside": L // 2, "sampling": "healpix", "method": "jax_cuda"}
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
f = s2fft.inverse(flm, **kwargs)
f_copy = f.copy()
print(f"Before forward call {np.allclose(f, f_copy) = }")
print(f"Before forward call {f - f_copy = }")
flm_recov = s2fft.forward(f, **kwargs)
print(f"After first forward call {np.allclose(f, f_copy) = }")
print(f"After first forward call {f - f_copy = }")
flm_recov_2 = s2fft.forward(f, **kwargs)
print(f"After second forward call {np.allclose(f, f_copy) = }")
print(f"After second forward call {f - f_copy = }")
print(f"{np.allclose(flm_recov, flm_recov_2) = }")
```

outputs results showing that f no longer matches f_copy after the first forward call. I am going to continue to try to debug where this memory issue is arising and see if it relates to the segfaults, but thought I'd give a (long overdue) update first just so you know this is on my radar!
Looking at the code in a bit more detail, I think at least part of what is causing this issue is that the forward transform appears to perform the per-ring FFTs in place on the input buffer (lines 127 to 142 in 3bcb69a, with pointers to offsets within the input array at lines 130 to 131 in 3bcb69a). Supporting this hypothesis, if we inspect the values in the mutated underlying array:

```python
import jax

jax.config.update("jax_enable_x64", True)

import s2fft
import numpy as np
from s2fft.sampling import s2_samples

seed = 1234
L = 4
sampling = "healpix"
method = "jax_cuda"
nside = L // 2
reality = True
rng = np.random.default_rng(seed)
kwargs = {"L": L, "reality": reality, "nside": nside, "sampling": sampling}
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
f = s2fft.inverse(flm, **kwargs, method="jax")
f_copy = f.copy()
# Required to ensure a cached _npy_value is recorded before mutation
print(f"Before forward call: {np.all(f_copy == f._value) = }")
print(f"Before forward call: {np.all(f_copy == f._arrays[0]) = }")
s2fft.forward(f, **kwargs, method="jax_cuda")
print(f"After forward call: {np.all(f_copy == f._value) = }")
print(f"After forward call: {np.all(f_copy == f._arrays[0]) = }")
index = 0
for t in range(s2_samples.ntheta(L, sampling, nside)):
    nphi = s2_samples.nphi_ring(t, nside)
    fft_slice_matches = np.allclose(
        jax.numpy.fft.fft(f._value[index : index + nphi], norm="backward"),
        f._arrays[0][index : index + nphi],
    )
    print(f"Ring {t}, {fft_slice_matches = }")
    index += nphi
```

outputs a match for each ring, i.e. after the forward call each ring of the underlying array holds the FFT of the corresponding ring of the original input. There seems to be something more going on than just this however, as if we change the lines

```python
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
f = s2fft.inverse(flm, **kwargs, method="jax")
```

to

```python
f = jax.numpy.asarray(rng.standard_normal(s2_samples.f_shape(L, sampling, nside)))
```

so removing the inverse transform call, the behaviour changes, suggesting in this case the issue also depends on how the input array was produced.
I just deactivated the "forward" part like so:

```cpp
// Step 2j: Non-batched case.
// Step 2k: Get device pointers for data, output, and workspace.
fft_complex_type* data_c = reinterpret_cast<fft_complex_type*>(input.typed_data());
fft_complex_type* out_c = reinterpret_cast<fft_complex_type*>(output->typed_data());
fft_complex_type* workspace_c = reinterpret_cast<fft_complex_type*>(workspace->typed_data());
// Step 2l: Get or create an s2fftExec instance from the PlanCache.
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
PlanCache::GetInstance().GetS2FFTExec(descriptor, executor);
// Step 2m: Launch the forward transform.
std::cout << "Commenting forward call" << std::endl;
// executor->Forward(descriptor, stream, data_c, workspace_c);
// Step 2n: Launch spectral extension kernel with shift and normalization.
int kernel_norm = (descriptor.norm == s2fftKernels::fft_norm::FORWARD) ? 0
                  : (descriptor.norm == s2fftKernels::fft_norm::ORTHO) ? 1
                                                                       : 2;
s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside,
                                        descriptor.harmonic_band_limit,
                                        descriptor.shift, kernel_norm, stream);
return ffi::Error::Success();
```

and now the input is not being corrupted, so your intuition is correct.

So now the ordeal is that I can obviously patch this, but it requires me to copy the input map (or make a scratch copy). This is not possible here since the output does not have the same size as the input. I have to study the FFI literature more, but the simplest way to do this is to use the scratch allocator to make temporary memory, OR have an option where the user can donate the buffer or not. I will look in the FFI docs to see if there is something I can do for this without duplicating memory. In any case, going from complex to (un)folded has to be done out of place if we don't want to use complicated synchronization in CUDA.
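On the donation option: if the binding goes through `jax.ffi.ffi_call`, one mechanism worth checking (an assumption about what could be wired up, not what the PR currently does) is its `input_output_aliases` argument, which declares to XLA that an input buffer may be reused for an output, making in-place writes legitimate:

```python
import jax
import jax.numpy as jnp

def healpix_fft_donated(f):
    # Hypothetical target name; aliasing input 0 to output 0 lets the handler
    # write into the donated input buffer instead of silently mutating it.
    out_type = jax.ShapeDtypeStruct(f.shape, jnp.complex128)
    return jax.ffi.ffi_call(
        "healpix_fft_cuda", out_type, input_output_aliases={0: 0}
    )(f)
```

Aliasing requires matching buffer sizes, so on its own it would not cover the case here where the output shape differs from the input.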
Thanks for checking this @ASKabalan.
For the latter option: if we had an additional output corresponding to the intermediate per-ring FFTs applied to the input for the underlying primitive, that could serve as the scratch buffer, avoiding both mutating the input and making an extra copy.
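A sketch of that idea from the Python side (all names hypothetical): declare the intermediate per-ring FFT buffer as an extra FFI output, so the kernel writes intermediates there instead of into its input, and simply discard it in the wrapper:

```python
import jax
import jax.numpy as jnp

def healpix_fft_with_workspace(f, out_shape):
    result_types = (
        jax.ShapeDtypeStruct(out_shape, jnp.complex128),  # actual result
        jax.ShapeDtypeStruct(f.shape, jnp.complex128),    # per-ring FFT scratch
    )
    out, _workspace = jax.ffi.ffi_call("healpix_fft_cuda", result_types)(f)
    return out  # the intermediate buffer is dropped; XLA owns its allocation
```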
```diff
- multiple_results=False,
+ multiple_results=True,  # Indicates that the primitive returns multiple outputs.
  abstract_evaluation=_healpix_fft_cuda_abstract,
  lowering_per_platform={None: _healpix_fft_cuda_lowering},
```

Suggested change:

```diff
- lowering_per_platform={None: _healpix_fft_cuda_lowering},
+ lowering_per_platform={"cuda": _healpix_fft_cuda_lowering},
```

If we changed this, would this give us an earlier / more informative error message if a user tries to lower on a non-CUDA device?
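For reference, the stock JAX machinery for this (a sketch independent of s2fft's own registration helper; the primitive object is a hypothetical stand-in) registers the lowering for a single platform, so other platforms fail early at lowering time with a missing-rule error:

```python
from jax.interpreters import mlir

# Registering the lowering only for "cuda" means attempts to lower on
# cpu/tpu fail immediately rather than reaching the CUDA handler.
mlir.register_lowering(
    _healpix_fft_cuda_p, _healpix_fft_cuda_lowering, platform="cuda"
)
```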
```python
MSE = jnp.mean(
    (jax.vmap(healpix_jax)(f_stacked) - jax.vmap(healpix_cuda)(f_stacked)) ** 2
)
assert MSE < 1e-14
# test jacfwd
MSE = jnp.mean(
    (jax.jacfwd(healpix_jax)(f.real) - jax.jacfwd(healpix_cuda)(f.real)) ** 2
)
assert MSE < 1e-14
# test jacrev
MSE = jnp.mean(
    (jax.jacrev(healpix_jax)(f.real) - jax.jacrev(healpix_cuda)(f.real)) ** 2
)
assert MSE < 1e-14
```
Ideally these should use `assert_allclose`, similar to other tests, as this gives a more informative error message with a summary of differences if a test fails, and allows checking for differences at a per-element level rather than via an overall norm.
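A sketch of the suggested rewrite for the first check (the tolerance is illustrative):

```python
from numpy.testing import assert_allclose

# Element-wise comparison: on failure this reports the mismatch fraction and
# the maximum absolute/relative differences, unlike an aggregate MSE assert.
assert_allclose(
    jax.vmap(healpix_jax)(f_stacked),
    jax.vmap(healpix_cuda)(f_stacked),
    atol=1e-14,
)
```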
```python
# Test VMAP
MSE = jnp.mean(
    (
        jax.vmap(healpix_inv_jax)(ftm_stacked)
        - jax.vmap(healpix_inv_cuda)(ftm_stacked)
    )
    ** 2
)
assert MSE < 1e-14
# test jacfwd
MSE = jnp.mean(
    (jax.jacfwd(healpix_inv_jax)(ftm.real) - jax.jacfwd(healpix_inv_cuda)(ftm.real))
    ** 2
)
assert MSE < 1e-14
# test jacrev
MSE = jnp.mean(
    (jax.jacrev(healpix_inv_jax)(ftm.real) - jax.jacrev(healpix_inv_cuda)(ftm.real))
    ** 2
)
assert MSE < 1e-14
```
Similar to the comment above, it would be better if these checks used `assert_allclose` with appropriate tolerances to give more informative error messages on failures.
Yeah, this is very smart indeed. I will implement this when I get the chance.
Adding a few updates.

A batching rule seems to be very important for two things:

1. Being able to use `jacrev` / `jacfwd`.
2. In most cases the size of a HEALPix map can fit on a single GPU, but sometimes we want to batch the spherical transform.

I will be doing that next; see the sketch below.
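For reference, a minimal sketch of what such a batching rule can look like for a single-result custom primitive (the primitive object and its name are hypothetical stand-ins, and the serial `lax.map` is just the simplest correct fallback; a production rule would dispatch to a genuinely batched CUDA lowering):

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching

def _healpix_fft_cuda_batch(batched_args, batch_dims, **params):
    (f,), (bdim,) = batched_args, batch_dims
    f = jnp.moveaxis(f, bdim, 0)  # canonicalise the batch axis to the front
    # Simplest correct fallback: apply the primitive per batch element.
    out = jax.lax.map(lambda x: _healpix_fft_cuda_p.bind(x, **params), f)
    return out, 0  # result is batched along the leading axis

batching.primitive_batchers[_healpix_fft_cuda_p] = _healpix_fft_cuda_batch
```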