Dear All, sorry for the stupid question, but I wasn't able to find an answer in the JAX documentation. In my current code, I use a jitted function inside a Python loop. Unfortunately, this leads to excessive RAM usage and causes Colab instances to crash. I suspect that this problem is of the same kind as the one described in the following discussion: I tried to use the solution suggested there by @mattjj. Unfortunately, it looks like there were some breaking changes in `jax.interpreters.xla`, because when I call

```python
from jax.interpreters import xla
xla._xla_callable.cache_clear()
```

I receive the following error message:

```
AttributeError                            Traceback (most recent call last)
AttributeError: module 'jax.interpreters.xla' has no attribute '_xla_callable'
```

I would like to ask how to access the `cache_clear()` function in the current version of JAX/XLA. Thanks in advance,
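One common way a jitted function keeps recompiling inside a loop is re-applying `jax.jit` to a freshly created function on every iteration. This is only an assumption, since the question does not show the actual loop. The minimal sketch below counts traces with a Python side effect, which runs only while JAX traces the function. `body`, `trace_count`, and `wrappers` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def body(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces `body`
    return x * 2.0 + 1.0


x = jnp.arange(3.0)

# Hypothetical anti-pattern: wrap a new lambda with jit on every iteration.
# Each wrapper is a distinct function object, so each call traces and compiles again.
wrappers = []  # keep references so every wrapper stays alive during the demo
for _ in range(3):
    g = jax.jit(lambda v: body(v))
    g(x)
    wrappers.append(g)
traces_fresh = trace_count

# Preferred: jit once, outside the loop, and reuse the same function object.
body_jit = jax.jit(body)
for _ in range(3):
    body_jit(x)  # same shape and dtype each time, so only the first call traces
traces_hoisted = trace_count - traces_fresh

print(traces_fresh, traces_hoisted)
```

If the per-iteration count keeps growing in a real loop, hoisting the `jax.jit` call (or keeping input shapes and dtypes fixed) is usually worth checking before clearing caches.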
---
If you're interested in clearing the cache for a single jit-compiled function, you can use the `_clear_cache()` method of the function. For example:

```python
from jax import jit

@jit
def f(x):
    return x + 1

print(f._cache_size())
# 0
f(1.0)  # first call with a float argument compiles one entry
print(f._cache_size())
# 1
f(1)  # an int argument is a new signature, so a second entry is compiled
print(f._cache_size())
# 2
f._clear_cache()  # drop all compiled entries for this function only
print(f._cache_size())
# 0
```

If you're interested in clearing the cache more globally, I'd suggest following the feature request in #10828.
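Applied to the loop in the question, a minimal sketch is to watch the per-function cache and clear it once it grows past a limit. It assumes a recent JAX release. `_cache_size()` and `_clear_cache()` are private methods that may change, and `step`, `MAX_ENTRIES`, and `results` are hypothetical names. Newer JAX versions also provide `jax.clear_caches()`, which clears all compilation and staging caches instead of one function's.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Each new input shape triggers a fresh trace and compile of `step`.
    return jnp.sum(x ** 2)


MAX_ENTRIES = 4  # hypothetical threshold; tune it to your memory budget

results = []
for n in range(1, 11):
    # A different shape on every iteration adds a new cache entry.
    results.append(float(step(jnp.ones(n))))
    if step._cache_size() > MAX_ENTRIES:
        # Drop this function's compiled executables (private API).
        step._clear_cache()

# Global alternative in newer JAX releases: clear every compilation cache.
jax.clear_caches()
```

Clearing only helps if the entries are not needed again. When shapes recur, each cleared entry has to be recompiled on its next use.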