How to clear the jit cache for a function? Dropping function reference doesn't clear the cache? #7882
josephrocca asked this question in Q&A (Replies: 1 comment)
-
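Re-defining the function itself seems to prevent the leak. So it's "fixed" if I change the code from this:

```python
import time

import jax.numpy as np  # assumption: `np` is jax.numpy, since f is traced by jit
from jax import jit
from jax.interpreters import xla  # `_xla_callable` is a private JAX internal of the version used here

# `get_gpu_memory()` is a helper from the original question's example (not shown here).

def f(x):
    return np.mean(np.dot(x, x.T))

size = 1002
for i in range(1000):
    if i % 100 == 0: print(f"\nROUND #{i}")
    size -= 1  # a new input shape on every iteration
    a = np.ones([size, size])
    jit_f = None  # drop the previous jitted wrapper
    time.sleep(0.01)
    xla._xla_callable.cache_clear()
    if i % 100 == 0: print("MEM:", get_gpu_memory())
    jit_f = jit(f)
    out = jit_f(a)
    if i % 100 == 0: print(out)
    if i % 100 == 0: print("MEM:", get_gpu_memory())
```

to this, where the only change is that `f` is defined inside the loop:

```python
# Same imports and get_gpu_memory() helper as in the previous snippet.
size = 1002
for i in range(1000):
    if i % 100 == 0: print(f"\nROUND #{i}")
    size -= 1
    a = np.ones([size, size])
    jit_f = None
    time.sleep(0.01)
    xla._xla_callable.cache_clear()
    if i % 100 == 0: print("MEM:", get_gpu_memory())
    def f(x):  # a new function object on every iteration
        return np.mean(np.dot(x, x.T))
    jit_f = jit(f)
    out = jit_f(a)
    if i % 100 == 0: print(out)
    if i % 100 == 0: print("MEM:", get_gpu_memory())
```

I created an issue about this here: #7930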
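Re-defining `f` helps, and dropping `jit_f` does not. This would fit a cache keyed on the underlying Python function (held through a weak reference) plus the input shape, rather than on the `jit(f)` wrapper. That is an assumption about JAX internals, not something confirmed in this thread. The toy model below reproduces both observations. `ToyCompileCache` and its `jit` method are hypothetical, shape tuples stand in for arrays, and strings stand in for compiled executables:

```python
import gc
import weakref


class ToyCompileCache:
    """Toy model, NOT JAX's implementation: fake compiled results are stored
    per Python function (held weakly) and then per input shape."""

    def __init__(self):
        # function object -> {shape: fake "executable"}; an entry disappears
        # automatically once the function object itself is garbage-collected.
        self._per_fn = weakref.WeakKeyDictionary()
        self.compilations = 0

    def jit(self, fn):
        # Hypothetical stand-in for jax.jit: the returned wrapper owns no cache
        # entries; it only keeps `fn` alive while the wrapper itself is alive.
        def wrapper(shape):
            entries = self._per_fn.setdefault(fn, {})
            if shape not in entries:  # unseen shape -> "compile" again
                self.compilations += 1
                entries[shape] = f"executable for {fn.__name__}{shape}"
            return entries[shape]
        return wrapper

    def num_entries(self):
        return sum(len(shapes) for shapes in self._per_fn.values())


# Pattern 1 (first snippet): f is defined once; only the wrapper is dropped.
cache = ToyCompileCache()

def f(x):
    return x

for size in range(1001, 991, -1):  # 10 distinct (size, size) shapes
    jit_f = cache.jit(f)
    jit_f((size, size))
    jit_f = None  # the wrapper goes away, but f (the cache key) does not
gc.collect()
print(cache.num_entries())  # 10

# Pattern 2 (second snippet): the function is re-defined on every iteration.
cache2 = ToyCompileCache()
for size in range(1001, 991, -1):
    def f2(x):  # a brand-new function object each time
        return x
    jit_f2 = cache2.jit(f2)
    jit_f2((size, size))
    jit_f2 = None  # rebinding f2 next round frees the old function and its entries
gc.collect()
print(cache2.num_entries())  # 1 (only the last f2 is still alive)
```

In this model, setting `jit_f = None` cannot shrink the cache while the module-level `f` is still referenced. The weak-keyed entries only go away once the function object itself becomes unreachable.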
-
Hello, I'm a JAX newbie. While debugging what appeared to be a memory leak in Huggingface's `FlaxCLIPModel`, I realised that it was due to JAX's `jit` caching behavior. I was led to believe (by this comment from Matt) that dropping the reference to a function would free/clear the `jit` cache for that function. But this doesn't seem to be the case. Here's a simple/minimal example that leaks about 1 MB every 100 iterations:
And here's a closer-to-real-world example (from here) that leaks hundreds of megabytes per iteration, since the function being `jit`ted is much more complex:

As you can see from those examples, I've tried `os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"` as mentioned here, and `xla._xla_callable.cache_clear()` as mentioned here, but neither of those helped.

Is this expected behavior? And if so, is there a way to clear the `jit` cache for a function? Thanks! 🙏
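The examples in this thread call the jitted function with a different input shape on every iteration (`size -= 1`). `jit` retraces and recompiles whenever it sees a new input shape, so every iteration can add another compiled entry. The sketch below models that with `functools.lru_cache`. `fake_compile` is a hypothetical stand-in for tracing and compiling `f`, not a JAX API. The sketch contrasts an unbounded cache with a bounded one:

```python
import functools

compile_count = 0


def fake_compile(shape):
    """Hypothetical stand-in for tracing and compiling f for one input shape."""
    global compile_count
    compile_count += 1
    return f"executable for shape {shape}"


# Unbounded cache: one entry per distinct shape, never evicted.
unbounded = functools.lru_cache(maxsize=None)(fake_compile)
# Bounded cache: at most 8 entries; the least recently used one is evicted.
bounded = functools.lru_cache(maxsize=8)(fake_compile)

size = 1002
for i in range(1000):
    size -= 1  # a new (size, size) shape on every iteration, as in the thread
    unbounded((size, size))
    bounded((size, size))

print(unbounded.cache_info().currsize)  # 1000 entries kept alive
print(bounded.cache_info().currsize)    # 8 entries
print(compile_count)                    # 2000: every call missed in both caches

unbounded.cache_clear()  # empties the cache and resets its statistics
print(unbounded.cache_info().currsize)  # 0
```

A bounded cache caps memory at the cost of recompiling any shape that was evicted. Later JAX releases also added a public `jax.clear_caches()` function for clearing compilation caches. Whether it is available depends on your JAX version.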