Problem with jax_cosmo caching system #140

@ASKabalan

@EiffL
I put together this 'short' snippet to reproduce the tracer leak error that I consistently get when doing forward modeling with jaxpm:

import jax
import jax.numpy as jnp
from diffrax import Euler, ODETerm, diffeqsolve

import jax_cosmo as jc

cosmo = jc.Planck15()


def ode_with_chi(t, y, cosmo):
    distance = jc.background.radial_comoving_distance(cosmo, t).squeeze()
    return jnp.squeeze(distance) * 0.1 + y


def ode_with_gp(t, y, cosmo):
    growth_factor = jc.background.growth_factor(cosmo, jnp.array([t]))
    return jnp.squeeze(growth_factor) * 0.1 + y


def ode_wrapper(cosmo):
    def fn(t, y, args):
        return ode_with_chi(t, y, cosmo)

    return fn


def integrate(cosmo, y, ode_term):
    result = diffeqsolve(
        ode_term,
        Euler(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y,
        args=cosmo,  # This is the key part that causes the leak!
    )
    return result.ys


@jax.jit
def integrate_with_chi(cosmo):
    ode_term = ODETerm(ode_with_chi)
    return integrate(cosmo, jnp.array([1.0]), ode_term)


@jax.jit
def integrate_with_gp(cosmo):
    ode_term = ODETerm(ode_with_gp)
    return integrate(cosmo, jnp.array([1.0]), ode_term)


@jax.jit
def integrate_with_wrapper(cosmo):
    ode_term = ODETerm(ode_wrapper(cosmo))
    return integrate(cosmo, jnp.array([1.0]), ode_term)


@jax.jit
def integrate_with_chi_precompute(cosmo):
    distance = jc.background.radial_comoving_distance(cosmo, 0.1).squeeze()
    y0 = jnp.array([distance])
    ode_term = ODETerm(ode_with_chi)
    return integrate(cosmo, y0, ode_term)


@jax.jit
def integrate_with_chi_precompute_and_wrapper(cosmo):
    _ = jc.background.radial_comoving_distance(cosmo, 0.1).squeeze()
    y0 = jnp.array([1.0])
    ode_term = ODETerm(ode_wrapper(cosmo))
    return integrate(cosmo, y0, ode_term)


integrate_with_chi(cosmo)  # Leaks
integrate_with_gp(cosmo)  # Leaks
integrate_with_wrapper(cosmo)  # Leaks
integrate_with_chi_precompute(cosmo)  # Leaks
integrate_with_chi_precompute_and_wrapper(cosmo)  # DOES NOT leak

I think I can even reproduce this without diffrax, since the error seems to be caused by diffrax calling eval_shape, which traces the caching code twice.
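To illustrate the claim above, here is a minimal sketch that uses neither diffrax nor jax_cosmo. The `Workspace` class and `cached_table` helper are hypothetical stand-ins for an object with a mutable cache; they are not the jax_cosmo implementation. The idea is only to show that writing a traced value into a cache that outlives the trace is reported as a leak when `jax.checking_leaks()` is active.

```python
import jax
import jax.numpy as jnp


class Workspace:
    """Hypothetical stand-in for an object carrying a mutable cache."""

    def __init__(self):
        self.cache = {}


def cached_table(ws, x):
    # Compute once, then reuse: mimics a memoizing helper.
    if "table" not in ws.cache:
        # Under jit this value is a tracer, and it is stored outside the trace.
        ws.cache["table"] = jnp.cumsum(x)
    return ws.cache["table"]


ws = Workspace()


def f(x):
    return cached_table(ws, x).sum()


leak_detected = False
with jax.checking_leaks():
    try:
        jax.jit(f)(jnp.arange(4.0))
    except Exception as err:  # the leak checker reports the problem by raising
        leak_detected = True
        print(type(err).__name__, str(err).splitlines()[0])
```

After the call, `ws.cache["table"]` still holds the tracer from the finished trace, which is the same pattern as the `_workspace` entries in the report below.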

leak report


  File "/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/diffrax/_integrate.py", line 621, in loop
    _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
  File "/home/wassim/micromamba/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace DynamicJaxprTrace. Leaked tracer(s):

Traced<ShapedArray(float32[256])>with<DynamicJaxprTrace>
The error occurred while tracing the function _fn at /home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/equinox/_eval_shape.py:31 for jit. 
<DynamicJaxprTracer 130500598565440> is referred to by <dict 130500600235072>['a']
<dict 130500600235072> is referred to by <dict 130500599962048>['background.radial_comoving_distance']
<dict 130500599962048> is referred to by <Cosmology 130500600156928>._workspace
<Cosmology 130500600156928> is referred to by <tuple 130500600588864>[5]

Traced<ShapedArray(float32[256])>with<DynamicJaxprTrace>
The error occurred while tracing the function _fn at /home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/equinox/_eval_shape.py:31 for jit. 
<DynamicJaxprTracer 130500598573680> is referred to by <dict 130500600235072>['chi']
<dict 130500600235072> is referred to by <dict 130500599962048>['background.radial_comoving_distance']
<dict 130500599962048> is referred to by <Cosmology 130500600156928>._workspace
<Cosmology 130500600156928> is referred to by <tuple 130500600588864>[5]

My workaround is not good enough, since cosmo is not passed as args. diffrax presumably guarantees that the adjoint with respect to args is handled, but if cosmo is passed around through a closure, as in the wrapper example, I don't think the gradients are propagated correctly.
Also, precomputing is not very clean.

tl;dr: I think the jax_cosmo caching system could be made more robust, probably by using compiler directives such as jax.ensure_compile_time_eval.
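As a rough sketch of what that could look like, the snippet below fills a cache inside `jax.ensure_compile_time_eval()`, so the stored value is a concrete array rather than a tracer. `Workspace` and `cached_table` are hypothetical names for illustration only, not jax_cosmo code. One caveat, as far as I understand it: this only helps when the cached quantity does not depend on traced inputs. A table that depends on traced cosmological parameters would still be a tracer, so this is not a complete fix.

```python
import jax
import jax.numpy as jnp


class Workspace:
    """Hypothetical stand-in for an object carrying a mutable cache."""

    def __init__(self):
        self.cache = {}


def cached_table(ws, n):
    if "table" not in ws.cache:
        # Inputs here are plain Python/concrete values, so the ops run
        # eagerly and the cache receives a concrete array, not a tracer.
        with jax.ensure_compile_time_eval():
            ws.cache["table"] = jnp.cumsum(jnp.arange(n, dtype=jnp.float32))
    return ws.cache["table"]


ws = Workspace()


def f(x):
    return x * cached_table(ws, 4).sum()


# Trace twice, the way eval_shape followed by jit would.
shape = jax.eval_shape(f, jnp.float32(2.0))  # first trace fills the cache
result = jax.jit(f)(jnp.float32(2.0))        # second trace reuses it
print(shape.shape, float(result))
```

Because the cached array is concrete, the second trace can safely reuse the value stored during the first one.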

I will start on a PR when I get the chance.
In the meantime, tell me what you think.
