Set GPU memory allocation via Python (programmatically) #6102
Unanswered
fabiannagel asked this question in Q&A
Replies: 0 comments
Hi! I've been benchmarking JAX-MD and noticed some strange runtime variance, which I have tracked down to JAX's GPU memory allocation behavior. For that reason, I want to run JAX in various VRAM allocation modes. I know it works from a terminal with
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
but I'd prefer to do it programmatically via Python. Updating the OS environment variables is not the issue, but I noticed that the JAX module needs to be reloaded for the change to take effect. Applying the sledgehammer via importlib.reload() causes some issues with pickling, which is how I save my benchmark results.

Therefore my question: Is there a "nice" and simple way to get JAX to reload XLA environment variables without doing a full-blown re-import of the entire module? Thanks for your help :)
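Since, as observed above, JAX does not pick up a changed allocator variable once it has initialized, one possible workaround (a sketch, not an official JAX mechanism) is to avoid reloading altogether and run each allocator mode in a fresh Python process with its own environment. The names `MODES`, `CHILD_SCRIPT` and `run_mode` are hypothetical; `XLA_PYTHON_CLIENT_ALLOCATOR` and `XLA_PYTHON_CLIENT_PREALLOCATE` are JAX's documented GPU memory settings. The child script is a placeholder that only reports the allocator setting it sees (so the sketch runs without a GPU); a real benchmark would import JAX and JAX-MD there.

```python
import json
import os
import subprocess
import sys

# Allocator configurations to compare. "default" leaves JAX's own defaults in place.
MODES = {
    "default": {},
    "platform": {"XLA_PYTHON_CLIENT_ALLOCATOR": "platform"},
    "no_prealloc": {"XLA_PYTHON_CLIENT_PREALLOCATE": "false"},
}

# Hypothetical placeholder for the benchmark body. It runs in a fresh interpreter,
# so any JAX import here would see the environment set by run_mode().
CHILD_SCRIPT = """
import json, os
# import jax  # a real benchmark would import JAX here, after the env is set
print(json.dumps({"allocator": os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "default")}))
"""


def run_mode(extra_env: dict) -> dict:
    """Run CHILD_SCRIPT in a new Python process with extra_env applied."""
    # Copy the parent's environment but drop inherited XLA client settings,
    # so every mode starts from the same baseline.
    env = {k: v for k, v in os.environ.items() if not k.startswith("XLA_PYTHON_CLIENT_")}
    env.update(extra_env)
    proc = subprocess.run(
        [sys.executable, "-c", CHILD_SCRIPT],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise if the child process fails
    )
    # The child reports its result as a single JSON line on stdout.
    return json.loads(proc.stdout)


if __name__ == "__main__":
    results = {name: run_mode(extra) for name, extra in MODES.items()}
    print(results)
```

Because the parent process never imports JAX, the collected results are plain dictionaries that can be pickled as usual, and no module ever has to be reloaded.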