
If you want to clear the cache of a single jit-compiled function, you can call that function's _clear_cache() method. Note the leading underscore: this is a private method, so its behavior may change between JAX releases. For example:

from jax import jit

@jit
def f(x):
  return x + 1

print(f._cache_size())
# 0

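f(1.0)  # first call: f is traced and compiled for a float scalar
print(f._cache_size())
# 1

f(1)    # an int argument has a different dtype, so f is compiled a second time
print(f._cache_size())
# 2

f._clear_cache()  # discard every compiled version of f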
print(f._cache_size())
# 0
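To illustrate that _clear_cache() only affects the function it is called on, here is a small sketch with a second jitted function g (f and g are illustrative names). It relies on the same private _cache_size() and _clear_cache() methods used above, so treat it as version-dependent:

```python
from jax import jit

@jit
def f(x):
  return x + 1

@jit
def g(x):
  return x * 2

# Compile one version of each function.
f(1.0)
g(1.0)

# Clear only f's cache; g keeps its compiled version.
f._clear_cache()
print(f._cache_size(), g._cache_size())

# Calling f again recompiles it on demand.
f(2.0)
print(f._cache_size())
```

Clearing the cache does not break the function. The next call just pays the tracing and compilation cost again.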

If you're interested in clearing the cache more globally, I'd suggest following the feature request in #10828.
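More recent JAX releases also expose a public jax.clear_caches() function that clears compilation and staging caches for all functions at once. The following minimal sketch assumes your installed JAX version provides it, since older versions may not. The function name h is illustrative:

```python
import jax

@jax.jit
def h(x):
  return x + 1

h(1.0)

# Clear JAX's compilation and staging caches globally, not just for h.
# Assumes a JAX release that provides jax.clear_caches().
jax.clear_caches()

# Jitted functions still work afterwards; they are retraced and
# recompiled on their next call.
print(h(1.0))
```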

Answer selected by Honza9723