Hello,
I am currently working with the Google Scenic project (https://github.com/google-research/scenic/tree/main/scenic/projects/vivit), which is based on JAX.
One of the problems I have encountered is that, during the evaluation step or train step, the process just hangs at random (it can happen after 1 hour of training or after 30 hours). The hang happens somewhere inside the pmapped eval/train function, and it occurs only when more than one GPU is used. Furthermore, during the hang the GPU utilization (I am currently using A100, V100 or P5000) is zero, while the CPU utilization is close to 100% on all cores. It hangs without any error or exception.
Basically, if the eval/train step function is vmapped, just jitted, or pmapped with a single GPU, the steps work perfectly.
I have run various experiments to find the cause, isolating the code from any dataset and from all libraries apart from jax, flax, ml_collections, scipy and tf, but nothing has helped. There is a clear connection between the hang and the multi-GPU configuration.
My environment is based on the latest versions of jax and jaxlib:
jax 0.3.13
jaxlib 0.3.10+cuda11.cudnn82
I have also previously tried different versions of jax and jaxlib (both "jax[cuda11_cudnn82]" and "jax[cuda11_cudnn805]").
My CUDA version is 11.4 and cuDNN is 8.3.
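For what it's worth, jax does see all the GPUs in these setups (training runs for hours before the hang). A quick check of the device configuration:

```python
import jax

print(jax.devices())             # lists one GPU device per visible GPU
print(jax.local_device_count())  # 4 on the machine the attached script is written for
```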
Currently, the problem is quite critical for the Vision Transformer use case, and I am looking for help or hints to solve it.
I attach to this post a Python file with a running script that is isolated from data and uses only basic packages such as jax, flax and tf. It is written for 4 GPUs; feel free to edit it at the end depending on your number of GPUs.
eval_steps_simple.py.txt
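The script boils down to roughly the following pattern. This is only a sketch with a toy model and synthetic data; the names are placeholders and it is not the exact contents of the attached file:

```python
import functools

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

@functools.partial(jax.pmap, axis_name='batch')
def eval_step(params, batch):
    # Stand-in for the ViT forward pass and metric computation.
    logits = batch @ params
    loss = jnp.mean((logits - 1.0) ** 2)
    # Cross-device reduction, as in the real train/eval steps.
    return jax.lax.pmean(loss, axis_name='batch')

# Replicate the parameters and shard a synthetic batch across devices.
params = jax.device_put_replicated(jnp.ones((16, 8)), jax.local_devices())
batch = jnp.ones((n_devices, 32, 16))

for step in range(1_000_000):
    loss = eval_step(params, batch)
    # With a single GPU (or with jax.jit / jax.vmap instead of jax.pmap)
    # this loop runs indefinitely; with more than one GPU it eventually hangs here.
    loss.block_until_ready()
```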