Memory consumption comparison between Flax and PyTorch #20621
Sun-Xiaohui asked this question in Q&A
Hello, JAX team! Sorry for asking here. This question belongs in the Flax issue tracker, but I didn't get an answer there.
When I port PyTorch models to Flax, I notice that Flax consumes more GPU memory than PyTorch, even though I have set "XLA_PYTHON_CLIENT_PREALLOCATE=false". For example, a ResNet50 model uses 4 GB of GPU memory in PyTorch but 6 GB in Flax. Is this expected? What causes the difference in memory consumption between PyTorch and Flax, and what can I do to reduce memory usage in Flax? Thanks!
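For anyone reproducing the comparison, here is a minimal sketch that sets the JAX GPU memory environment variables and then reads what the JAX allocator itself reports. Two things are assumptions. First, the helper names `configure_xla_gpu_memory` and `report_device_memory` are hypothetical. Second, the sketch assumes a recent JAX release where `Device.memory_stats()` exists, which may return nothing on CPU. `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_ALLOCATOR=platform` are the variables described in JAX's GPU memory allocation notes. They only take effect if they are set before JAX initializes its backend.

```python
import os


def configure_xla_gpu_memory(preallocate=False, use_platform_allocator=False):
    """Hypothetical helper: set JAX/XLA GPU memory env vars.

    Must run before JAX initializes its GPU backend (safest: before importing jax/flax).
    Returns the XLA_PYTHON_CLIENT_* variables currently in the environment.
    """
    # Turn off (or on) the default up-front reservation of most GPU memory.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if use_platform_allocator:
        # Allocate/free exactly on demand: slower, but closer to real usage.
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    return {k: v for k, v in os.environ.items() if k.startswith("XLA_PYTHON_CLIENT_")}


def report_device_memory():
    """Hypothetical helper: print bytes in use / peak per device, if available."""
    try:
        import jax  # imported only after the env vars above are set
    except ImportError:
        print("jax is not installed; nothing to report")
        return
    for dev in jax.devices():
        try:
            stats = dev.memory_stats() or {}
        except Exception:  # some backends/versions may not expose stats
            stats = {}
        in_use = stats.get("bytes_in_use")
        peak = stats.get("peak_bytes_in_use")
        print(f"{dev}: bytes_in_use={in_use}, peak_bytes_in_use={peak}")


if __name__ == "__main__":
    print(configure_xla_gpu_memory(preallocate=False))
    # ... build and run the Flax ResNet50 step here, then:
    report_device_memory()
```

Comparing the allocator's in-use and peak numbers in both frameworks may be more informative than `nvidia-smi` alone, because both JAX and PyTorch may keep memory reserved in their own pools beyond what is actively used.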