
Computation of precompute spin spherical kernel allocates large intermediate arrays #319


Description

@matt-graham

The s2fft.precompute_transforms.construct.spin_spherical_kernel function constructs an array with $O(L^3)$ elements. The current implementation raises an out of memory (OOM) error when trying to construct the kernel for the argument values

L = 1024; spin = 0; sampling = "mw"; reality = True; recursion = "auto"; method = "jax"

with the output array being of shape (2 * L + 1, L, L) = (2049, 1024, 1024), corresponding to 2049 × 1024 × 1024 × 8 bytes ≈ 16.0 GiB for the float64 data type, on an NVIDIA GH200 GPU with 96 GiB of device memory.
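For reference, the quoted figure can be checked with a quick back-of-the-envelope calculation (assuming 8 bytes per float64 element):

L = 1024
n_bytes = (2 * L + 1) * L * L * 8  # shape (2049, 1024, 1024), float64
print(n_bytes / 2**30)             # ≈ 16.0 GiB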

The OOM error flags the line

dl = dl.at[jnp.where(dl != dl)].set(0)

as the point where device memory is exhausted, with the full traceback:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "s2fft/s2fft/precompute_transforms/construct.py", line 242, in spin_spherical_kernel_jax
    dl = dl.at[jnp.where(dl != dl)].set(0)
               ^^^^^^^^^^^^^^^^^^^
  File "s2fft/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2810, in where
    return nonzero(condition, size=size, fill_value=fill_value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s2fft/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 3750, in nonzero
    bincount(reductions.cumsum(mask), length=calculated_size))
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s2fft/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2981, in bincount
    return array_creation.zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights, mode='drop')
                                                            ^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 34359730176 bytes.

34359730176 bytes ≈ 32 GiB, so it looks like this call is trying to allocate intermediate arrays roughly double the size of the output array.
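As a tiny illustration (the shape here is hypothetical, chosen only to keep the example small), the one-argument form of jnp.where goes through jnp.nonzero, which, per the traceback above, builds a cumulative sum over the full mask plus one integer index array per array dimension:

import jax.numpy as jnp

# Small stand-in for dl != dl: a (4, 3, 3) mask with a single True entry.
mask = jnp.zeros((4, 3, 3), dtype=bool).at[0, 0, 0].set(True)

# One-argument where == nonzero: internally computes a cumsum over the
# whole mask and returns one index array per dimension, so for a
# (2049, 1024, 1024) input the intermediates are comparable in size to
# the input itself.
idx = jnp.where(mask)
print([a.shape for a in idx])  # [(1,), (1,), (1,)]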

As far as I can tell, this line sets all NaN entries of dl to zero (NaN entries should be the only elements for which dl != dl evaluates to True). The one-argument jnp.where call maps the boolean array dl != dl to a tuple of integer index arrays giving the positions where the condition is True. This is inefficient, as we could use boolean indexing directly, but I also think we may be able to avoid materializing the intermediate boolean array altogether and potentially have the update applied in place with some refactoring; see the sketch below.
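For illustration, here is a minimal sketch of one elementwise alternative, using jnp.nan_to_num; this is an assumption about the shape of the refactoring, not necessarily the change intended above:

import jax.numpy as jnp

def zero_nans(dl):
    # Elementwise select: NaNs are replaced without going through the
    # nonzero/scatter path, so no integer index arrays or full-size
    # cumsum intermediates are materialized, and XLA can fuse this op
    # with neighbouring computation.
    return jnp.nan_to_num(dl, nan=0.0)
    # Equivalent formulation: jnp.where(jnp.isnan(dl), 0.0, dl)

x = jnp.array([[1.0, jnp.nan], [jnp.nan, 4.0]])
print(zero_nans(x))  # [[1. 0.] [0. 4.]]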

Metadata

Labels: enhancement (New feature or request)