Description
The s2fft.precompute_transforms.construct.spin_spherical_kernel function raises an out-of-memory (OOM) error when called with the arguments
L = 1024; spin = 0; sampling = "mw"; reality = True; recursion = "auto"; method = "jax"
The output array has shape (2 * L + 1, L, L) = (2049, 1024, 1024), corresponding to 2049 × 1024 × 1024 × 8 bytes ≈ 16.0 GiB for the float64 data type. This was run on an NVIDIA GH200 GPU with 96 GiB of device memory.
The OOM error flags the line
    dl = dl.at[jnp.where(dl != dl)].set(0)
as the point where device memory is exhausted. The full traceback is:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "s2fft/s2fft/precompute_transforms/construct.py", line 242, in spin_spherical_kernel_jax
    dl = dl.at[jnp.where(dl != dl)].set(0)
               ^^^^^^^^^^^^^^^^^^^
  File "s2fft/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2810, in where
    return nonzero(condition, size=size, fill_value=fill_value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s2fft/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 3750, in nonzero
    bincount(reductions.cumsum(mask), length=calculated_size))
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s2fft/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2981, in bincount
    return array_creation.zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights, mode='drop')
                                                            ^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 34359730176 bytes.
34359730176 bytes ≈ 32 GiB, so it looks like this is trying to allocate arrays that are roughly double the size of the output array.
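The size comparison can be checked with a few lines of plain Python. This is only a sanity check of the arithmetic, using the values quoted in this issue. The factorisation in the last step is an observation about the number itself; whether it reflects the actual shape of some intermediate array is an assumption I have not verified against the s2fft source.

```python
L = 1024
itemsize = 8   # bytes per float64 element
GiB = 2**30

# Size of the output array of shape (2 * L + 1, L, L).
output_bytes = (2 * L + 1) * L * L * itemsize

# Size of the allocation reported in the RESOURCE_EXHAUSTED error.
failed_bytes = 34359730176

print(f"output array:      {output_bytes / GiB:.4f} GiB")
print(f"failed allocation: {failed_bytes / GiB:.4f} GiB")
print(f"ratio:             {failed_bytes / output_bytes:.4f}")

# Observation (not confirmed by the traceback): the failed allocation equals
# (2L - 1) * (2L + 1) * L elements of 8 bytes each.
print(failed_bytes == (2 * L - 1) * (2 * L + 1) * L * itemsize)
```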
As far as I can tell, this line sets all NaN entries of dl to zero (I think NaNs would be the only elements for which dl != dl evaluates to True). The single-argument jnp.where call maps the boolean array dl != dl to a tuple of index arrays for which the condition is True. This is inefficient, as we can use boolean indexing directly. I also think we may be able to avoid separately materializing the intermediate boolean array altogether and, with some refactoring, potentially have the update applied in-place.
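As a minimal sketch of one possible alternative (not necessarily the fix that should go into s2fft), the three-argument form of jnp.where selects elementwise between a replacement value and dl, so no index arrays are built. The small array below is a hypothetical stand-in for dl. I have not measured whether this lowers peak memory in the L = 1024 case.

```python
import jax.numpy as jnp

# Small hypothetical stand-in for dl; the real array is far larger.
dl = jnp.array([[1.0, jnp.nan, 3.0],
                [jnp.nan, 5.0, jnp.inf]])

# Pattern from the traceback: single-argument jnp.where returns a tuple of
# integer index arrays, which are then used for a scatter update.
via_indices = dl.at[jnp.where(dl != dl)].set(0)

# Alternative: three-argument jnp.where picks 0.0 where dl is NaN and the
# original value elsewhere, without computing any indices.
via_select = jnp.where(jnp.isnan(dl), 0.0, dl)

# Both approaches should agree; the infinite entry is left untouched.
print(bool(jnp.array_equal(via_indices, via_select)))
```

I used jnp.where with jnp.isnan rather than jnp.nan_to_num here because nan_to_num also replaces infinities by default, which would change behaviour relative to the current line.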