@EiffL
I created this 'short' snippet to reproduce the tracer-leak error that I systematically get when doing forward modeling with jaxpm:
import jax
import jax.numpy as jnp
from diffrax import Euler, ODETerm, diffeqsolve
import jax_cosmo as jc

cosmo = jc.Planck15()


def ode_with_chi(t, y, cosmo):
    distance = jc.background.radial_comoving_distance(cosmo, t).squeeze()
    return jnp.squeeze(distance) * 0.1 + y


def ode_with_gp(t, y, cosmo):
    growth_factor = jc.background.growth_factor(cosmo, jnp.array([t]))
    return jnp.squeeze(growth_factor) * 0.1 + y


def ode_wrapper(cosmo):
    def fn(t, y, args):
        return ode_with_chi(t, y, cosmo)

    return fn


def integrate(cosmo, y, ode_term):
    result = diffeqsolve(
        ode_term,
        Euler(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y,
        args=cosmo,  # This is the key part that causes the leak!
    )
    return result.ys


@jax.jit
def integrate_with_chi(cosmo):
    ode_term = ODETerm(ode_with_chi)
    return integrate(cosmo, jnp.array([1.0]), ode_term)


@jax.jit
def integrate_with_gp(cosmo):
    ode_term = ODETerm(ode_with_gp)
    return integrate(cosmo, jnp.array([1.0]), ode_term)


@jax.jit
def integrate_with_wrapper(cosmo):
    ode_term = ODETerm(ode_wrapper(cosmo))
    return integrate(cosmo, jnp.array([1.0]), ode_term)


@jax.jit
def integrate_with_chi_precompute(cosmo):
    distance = jc.background.radial_comoving_distance(cosmo, 0.1).squeeze()
    y0 = jnp.array([distance])
    ode_term = ODETerm(ode_with_chi)
    return integrate(cosmo, y0, ode_term)


@jax.jit
def integrate_with_chi_precompute_and_wrapper(cosmo):
    _ = jc.background.radial_comoving_distance(cosmo, 0.1).squeeze()
    y0 = jnp.array([1.0])
    ode_term = ODETerm(ode_wrapper(cosmo))
    return integrate(cosmo, y0, ode_term)


integrate_with_chi(cosmo)  # Leaks
integrate_with_gp(cosmo)  # Leaks
integrate_with_wrapper(cosmo)  # Leaks
integrate_with_chi_precompute(cosmo)  # Leaks
integrate_with_chi_precompute_and_wrapper(cosmo)  # DOES NOT leak

I think I can even reproduce this without diffrax, since the error seems to be caused by diffrax calling eval_shape, which traces the cached computation a second time.
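A rough sketch of what a diffrax-free reproduction might look like. Everything here is a hypothetical stand-in: `ToyCosmology`, `toy_distance` and the toy table formula are not jax_cosmo code, they only imitate an object that memoizes tables in a `_workspace` dict. The inner `jax.eval_shape` call plays the role of the `eqx.filter_eval_shape` call seen in the traceback.

```python
import jax
import jax.numpy as jnp


class ToyCosmology:
    """Hypothetical stand-in for an object that caches tables on itself."""

    def __init__(self, h):
        self.h = h
        self._workspace = {}


def toy_distance(cosmo, a):
    # Build the interpolation table once and memoize it on the instance.
    if "chi" not in cosmo._workspace:
        a_tab = jnp.linspace(0.0, 1.0, 8)
        cosmo._workspace["chi"] = (a_tab, cosmo.h * a_tab**2)
    a_tab, chi_tab = cosmo._workspace["chi"]
    return jnp.interp(a, a_tab, chi_tab)


@jax.jit
def leaky(h):
    cosmo = ToyCosmology(h)  # created inside the outer jit trace

    def body(a):
        # First call happens inside the inner trace, so the cache is
        # filled with tracers that belong to that inner trace.
        return toy_distance(cosmo, a)

    # Abstract evaluation only; its trace ends when this call returns.
    jax.eval_shape(body, jax.ShapeDtypeStruct((), jnp.float32))
    # Would reuse the tracers cached by the finished inner trace.
    return toy_distance(cosmo, 0.5)


def run_leaky():
    """Run the example with leak checking on and return the raised error."""
    try:
        with jax.checking_leaks():
            leaky(jnp.float32(0.7))
    except Exception as err:  # the leak checker raises a plain Exception
        return err
    return None
```

If the diagnosis is right, `run_leaky()` should return an error rather than `None`, and I would expect its message to point at the `_workspace` dict in the same way as the report for the diffrax snippet.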
Leak report for the diffrax snippet:
File "/home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/diffrax/_integrate.py", line 621, in loop
_, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
File "/home/wassim/micromamba/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
next(self.gen)
Exception: Leaked trace DynamicJaxprTrace. Leaked tracer(s):
Traced<ShapedArray(float32[256])>with<DynamicJaxprTrace>
The error occurred while tracing the function _fn at /home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/equinox/_eval_shape.py:31 for jit.
<DynamicJaxprTracer 130500598565440> is referred to by <dict 130500600235072>['a']
<dict 130500600235072> is referred to by <dict 130500599962048>['background.radial_comoving_distance']
<dict 130500599962048> is referred to by <Cosmology 130500600156928>._workspace
<Cosmology 130500600156928> is referred to by <tuple 130500600588864>[5]
Traced<ShapedArray(float32[256])>with<DynamicJaxprTrace>
The error occurred while tracing the function _fn at /home/wassim/micromamba/envs/jax/lib/python3.10/site-packages/equinox/_eval_shape.py:31 for jit.
<DynamicJaxprTracer 130500598573680> is referred to by <dict 130500600235072>['chi']
<dict 130500600235072> is referred to by <dict 130500599962048>['background.radial_comoving_distance']
<dict 130500599962048> is referred to by <Cosmology 130500600156928>._workspace
<Cosmology 130500600156928> is referred to by <tuple 130500600588864>[5]
My workaround is not good enough, since the ODE term then gets cosmo from a closure instead of from args. diffrax presumably guarantees that the adjoint with respect to args is handled; if cosmo is passed around as in the wrapper example, I don't think the gradients are propagated correctly.
Also, precomputing is not very clean.
tl;dr: I think the jax_cosmo caching system could be made more robust, probably by using something like jax.ensure_compile_time_eval.
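To make the idea concrete, here is a minimal sketch under stated assumptions, not a proposal for the actual jax_cosmo code. `cached_table`, the plain-dict workspace and the toy formula are hypothetical. The parameter-independent grid is built under `jax.ensure_compile_time_eval`, so it is a concrete array that is safe to keep. The parameter-dependent table is only memoized when it is not a tracer; that part is a plain `isinstance` check against `jax.core.Tracer`, which is a separate technique and not something `ensure_compile_time_eval` does by itself.

```python
import jax
import jax.numpy as jnp


def cached_table(workspace, h):
    """Return (a_tab, chi_tab), caching only values that are safe to keep."""
    if "a" not in workspace:
        # No traced inputs are involved, so this evaluates eagerly to a
        # concrete array even inside jit or eval_shape.
        with jax.ensure_compile_time_eval():
            workspace["a"] = jnp.linspace(0.0, 1.0, 8)
    a_tab = workspace["a"]
    if "chi" in workspace:
        return a_tab, workspace["chi"]
    # Toy formula that depends on the (possibly traced) parameter h.
    chi_tab = h * a_tab**2
    if not isinstance(chi_tab, jax.core.Tracer):
        workspace["chi"] = chi_tab  # concrete value: safe to memoize
    return a_tab, chi_tab


workspace = {}


@jax.jit
def solve(h):
    def body(a):
        a_tab, chi_tab = cached_table(workspace, h)
        return jnp.interp(a, a_tab, chi_tab)

    # Inner abstract trace, mimicking what diffrax does before its loop.
    jax.eval_shape(body, jax.ShapeDtypeStruct((), jnp.float32))
    return body(0.5)


result = solve(jnp.float32(0.7))
```

After the call, `workspace` should hold only the concrete grid under `"a"`, with no `"chi"` entry, because every table computed from the traced `h` is recomputed per trace instead of being stored. The cost is losing the cache inside jit, which is probably acceptable since the table is baked into the compiled function anyway.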
I will start on a PR when I get the chance.
In the meantime, tell me what you think.