Replies: 4 comments 7 replies
-
I encountered the exact same issue: the pmapped step function works fine on a single GPU, but when scaling to multiple GPUs the step function may hang randomly with 0% GPU usage. There is high variance in the frequency, but I would say it happens once every 100k to 2M training steps (i.e. pmapped function calls). It is not a RAM or device memory issue. No error is thrown; training is just "paused" and you have to terminate it manually. The hang does not seem to depend on the data loading pipeline: when it happens, the batch is already loaded and available in device memory. The issue is also not hardware dependent, as it happens on both A100 and P5000 GPUs. I encountered it in both old (1.71) and recent (3.10) JAX versions, and with both haiku- and flax-based training frameworks. Finally, early tests seem to indicate that the issue does not depend on the cuDNN version (8.1.2 or 8.4).
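For concreteness, here is a stripped-down sketch of the kind of pmapped step function I am describing (toy model, hypothetical names, not my actual training code):

```python
import functools

import jax
import jax.numpy as jnp

# Toy pmapped SGD step: gradients are averaged across GPUs with
# lax.pmean; the hang always happens inside a call like this one.
@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, batch):
    def loss_fn(p):
        preds = batch["x"] @ p["w"] + p["b"]
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    grads = jax.lax.pmean(grads, axis_name="devices")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params, jax.lax.pmean(loss, axis_name="devices")
```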
-
Hello, I think I have essentially the same issue. I am currently working with the Google Scenic project (https://github.com/google-research/scenic/tree/main/scenic/projects/vivit), which is based on JAX. One of the problems I encountered is that during the evaluation step or train step the process just hangs randomly (it can happen after 1 hour of training or after 30 hours). The hang happens somewhere inside the pmapped eval/train function, and only when more than one GPU is used. During the hang the GPU utilization (I am currently using A100, V100 or P5000) is zero, but the CPU utilization is close to 100% on all cores, and there is no error or exception. If the eval/train step function is vmapped, only jitted, or pmapped on a single GPU, the step works perfectly.

I have tried different experiments to find the cause, isolating the code from any dataset and from all libraries apart from jax, flax, ml_collections, scipy and tf, but nothing has helped. There is an obvious connection between the error and the multi-GPU configuration. My environment is based on the latest versions of jax and jaxlib: jax 0.3.13. I have also previously tried different versions of jax and jaxlib (both "jax[cuda11_cudnn82]" and "jax[cuda11_cudnn805]"). My CUDA version is 11.4 and cuDNN is 8.3.

The problem is currently quite critical for our Vision Transformer use case, and I am looking for help or hints to solve it. I attach to this post a Python script that is isolated from data and uses only basic packages such as jax, flax and tf. It is written for 4 GPUs; feel free to edit it at the end depending on your number of GPUs.
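To make the comparison of wrappings concrete, this is roughly the pattern (toy step function and shapes for illustration only, not the attached script):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the eval/train step; only meant to show which
# wrappings behave correctly and which one hangs.
def step_fn(params, x):
    return jnp.mean(jnp.tanh(x @ params))

params = jnp.ones((16, 4))
x = jnp.ones((jax.local_device_count(), 8, 16))  # leading axis = number of GPUs

jit_step = jax.jit(step_fn)
vmap_step = jax.jit(jax.vmap(step_fn, in_axes=(None, 0)))
pmap_step = jax.pmap(step_fn, in_axes=(None, 0))

print(jit_step(params, x))   # fine
print(vmap_step(params, x))  # fine
print(pmap_step(params, x))  # fine on 1 GPU, hangs sporadically on several
```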
-
I created an issue for this since this has grown beyond simply being a discussion (#10969).
-
We're seeing the same issue on almost exactly the same setup: 8x A100 with 0% GPU / 100% CPU utilization, and the same JAX/Flax versions.
-
Hi,
We are facing a problem where training and validation code based on jax/flax hangs randomly on a multi-GPU host.
Using a single GPU works correctly, but once we add multi-GPU support it hangs in an unpredictable way.
GPU usage is at 0% for all GPUs while the CPU is busy.
What could be the problem?
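For reference, our multi-GPU path follows the usual jax/flax pmap pattern, roughly like this simplified sketch (placeholder shapes and a toy model, not our actual code):

```python
import functools

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Placeholder parameters; the real state is a flax train state.
params = {"w": jnp.zeros((32, 1))}
params = jax.device_put_replicated(params, jax.local_devices())

@functools.partial(jax.pmap, axis_name="batch")
def train_step(p, x, y):
    grads = jax.grad(lambda q: jnp.mean((x @ q["w"] - y) ** 2))(p)
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda a, g: a - 0.01 * g, p, grads)

for _ in range(100):
    # Each batch gets a leading device axis before the pmapped call;
    # the hang occurs unpredictably inside this call.
    x = jnp.ones((n_dev, 8, 32))
    y = jnp.ones((n_dev, 8, 1))
    params = train_step(params, x, y)
```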