-
From the activity monitor, the process's memory keeps growing during the for loop. How to deal with it?

```python
import jax
import jax.numpy as jnp
import numpy as np
from time import sleep
from functools import partial


class A:
    def __init__(self, name):
        self.name = name
        self.X = np.ones((5000, 5000))

    @partial(jax.jit, static_argnums=(0,))
    def f(self, Y):
        return jnp.sum(Y**2)


Y = np.ones(10)
for i in range(100):
    a = A(i)
    a.f(Y)
    sleep(2)
```
-
Hi - thanks for the question! In some senses this is intended behavior, in that the JAX jit cache uses strong references to track static arguments, so that if an argument goes out of scope the cache entry is not discarded. You can see why this would be useful, because, e.g., you wouldn't want to have to recompile a function just because you used a non-global string value as a static argument!

The side effect, then, is that if you have very large temporary static arguments, a strong reference will be created within the jit cache that will keep them in scope. This isn't really a leak per se: the object is deliberately referenced in order to keep cached versions of compiled functions and prevent having to re-compile.

You could address this by manually clearing the JIT cache of the function, in which case the temporary objects are garbage collected and the memory use remains constant:

```python
import os
import psutil

for i in range(41):
    A.f.clear_cache()
    a = A(i)
    a.f(Y)
    if i % 10 == 0:
        process = psutil.Process(os.getpid())
        print(f"Iteration {i}: {process.memory_info().rss} bytes")
# Iteration 0: 471412736 bytes
# Iteration 10: 873562112 bytes
# Iteration 20: 873660416 bytes
# Iteration 30: 873660416 bytes
# Iteration 40: 873672704 bytes
```

The jit cache is an LRU cache of a fixed size, so eventually unused static arguments will be discarded, but if your static arguments are large, you'll probably run out of memory before that point.

A better solution here would be to use `static_argnums` as JAX intends: i.e. pass small, hashable, truly static values. Note that your use of `static_argnums` here goes against the recommendations in the `jax.jit` documentation.
Your class does not implement `__hash__` or `__eq__`, so it falls back to Python's default identity-based hashing.

If you're working with classes that contain arrays, marking them as static is bad practice as it can easily result in surprises; here's a quick example:

```python
class A:
    def __init__(self):
        self.x = jnp.arange(5)

    @partial(jax.jit, static_argnums=0)
    def f(self):
        return self.x.sum()


a = A()
print(a.f())
# 10
a.x += 1  # mutation that doesn't change the hash
print(a.f())
# 10
```

The class has violated the contract of static arguments, which must be hashable and immutable.

A better approach for this kind of thing would be to register your class as a custom pytree, which is effectively a way to tell JAX which aspects of the class are static and which are dynamic. For your class it might look like this:

```python
from jax.tree_util import register_pytree_node


class A:
    def __init__(self, name):
        self.name = name
        self.X = np.ones((5000, 5000))

    @jax.jit  # note: no static argnums
    def f(self, Y):
        return jnp.sum(Y**2)


def _flatten_A(a):
    children = (a.X,)  # non-static values
    aux_data = (a.name,)  # static values
    return (children, aux_data)


def _unflatten_A(aux_data, children):
    X, = children
    name, = aux_data
    a = A(name)
    a.X = X
    return a  # return the rebuilt instance, not the class


register_pytree_node(A, _flatten_A, _unflatten_A)

a = A(1)
print(a.f(Y))
# 10.0
```

Does that make sense?
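The strong-reference behavior described in this answer can be seen without JAX at all. The following is only a rough sketch that uses `functools.lru_cache` as a stand-in for the jit cache; `Big`, `Config`, and `fake_compile` are hypothetical names made up for the illustration, and the real jit cache differs in its details. It assumes nothing beyond what is stated above: the cache keeps a strong reference to each static argument, clearing the cache releases them, and small value-hashed static arguments avoid the growth.

```python
import functools
import gc
import weakref
from dataclasses import dataclass


class Big:
    """Stand-in for a class holding a large array; hashed by object identity."""

    def __init__(self, name):
        self.name = name
        self.payload = bytearray(1_000_000)  # placeholder for a large array


@dataclass(frozen=True)
class Config:
    """Small, immutable, value-hashed stand-in for a truly static argument."""

    name: int


@functools.lru_cache(maxsize=None)
def fake_compile(static_arg):
    # Placeholder for compilation keyed on a static argument.
    return f"compiled for {static_arg!r}"


def call_with_temporary_big(name):
    obj = Big(name)
    fake_compile(obj)
    return weakref.ref(obj)  # lets us observe whether obj is still alive


# 1) Identity-hashed temporaries: one new cache entry each, all kept alive.
refs = [call_with_temporary_big(i) for i in range(5)]
gc.collect()
alive_while_cached = sum(r() is not None for r in refs)
entries_with_big = fake_compile.cache_info().currsize

# 2) Clearing the cache drops the strong references, so the objects are freed.
fake_compile.cache_clear()
gc.collect()
alive_after_clear = sum(r() is not None for r in refs)

# 3) Equal small configs hash and compare by value, so they share one entry.
for _ in range(5):
    fake_compile(Config(name=0))
entries_with_config = fake_compile.cache_info().currsize

print(alive_while_cached, entries_with_big, alive_after_clear, entries_with_config)
# 5 5 0 1
```

The weak references are there only to observe object lifetime; they do not keep the `Big` instances alive themselves.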
-
For context, cross-referencing patrick-kidger/diffrax#142 here as another example of the same issue. (Where the solution is the same -- clear the JIT cache.)

In addition, @pipme - if you want to do "JAX with classes", and in particular have the register-as-custom-pytree handled for you, then have a look at Equinox.
-
```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class A:
    def __init__(self, name):
        self.name = name
        self.X = jnp.ones((5000, 5000))

    @jax.jit  # note: no static argnums
    def f(self, Y):
        return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)


def _flatten_A(a):
    children = (a.X,)  # non-static values
    aux_data = (a.name,)  # static values
    return (children, aux_data)


def _unflatten_A(aux_data, children):
    (X,) = children
    (name,) = aux_data
    a = A(name)
    a.X = X
    return a


register_pytree_node(A, _flatten_A, _unflatten_A)

a = A(1)
for i in range(1, 5):
    a.f(jnp.ones(i))  # each new input shape adds a cache entry
print(A.f._cache_size())  # 4
del a
print(A.f._cache_size())  # 4
```

@jakevdp After registering the class as a custom pytree, I would like the jit cache to be cleared automatically, but it seems that it is not. I am aware of #7930 (comment). I need to create a lot of objects like […] Is calling […]