Intro
Hey,
The following bug report is a bit limited in its details, and I'm aware that the JAX-Warp backend is still very experimental, but I still wanted to report this issue to see if anyone has experienced something similar and hopefully save others some time.
My setup
mujoco==3.4.0, mujoco-mjx==3.4.0, warp-lang==1.10.1 on linux_x86_64 with 8xA100 GPUs.
What's happening? What did you expect?
We use Brax PPO to train RL policies. We've recently noticed that the jax.pmap of training_epoch in ppo/train.py hangs when donate_argnums=(0,1). What makes this hang especially annoying is that it only appears after something on the order of 10e6 training steps.
The reason I think this is an MJX-Warp FFI issue is that training works fine with impl=jax and only fails with impl=warp. Additionally, I've recently observed a similar hang in the MJX-Warp FFI in a completely independent setting: there, I was accessing Warp memory outside the allocated range, which silently deadlocked in a very similar way.
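For context, the out-of-range access in that second case looked roughly like the sketch below. This is an illustrative standalone Warp kernel, not the actual MJX-Warp code path; the kernel and array names are made up.

```python
import warp as wp

wp.init()


@wp.kernel
def read_out_of_range(src: wp.array(dtype=float), dst: wp.array(dtype=float)):
    tid = wp.tid()
    # Deliberately read past the end of `src`, the kind of out-of-range access
    # described above. Outside debug mode, Warp kernels are not bounds-checked,
    # so this touches memory beyond the allocation.
    dst[tid] = src[tid + 1024]


src = wp.zeros(16, dtype=float, device="cuda:0")
dst = wp.zeros(16, dtype=float, device="cuda:0")
wp.launch(read_out_of_range, dim=16, inputs=[src, dst], device="cuda:0")
wp.synchronize()
```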
A simple workaround for those who run into similar deadlocks/hangs is to remove the donate_argnums argument from the jax.pmap call.
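For reference, a minimal sketch of the workaround is below. The names and shapes are placeholders (training_epoch here is a dummy function and the axis_name is arbitrary), so this doesn't match brax's ppo/train.py exactly.

```python
import jax
import jax.numpy as jnp


def training_epoch(training_state, env_state):
    # Placeholder epoch function; in brax this would run one epoch of PPO updates.
    return training_state + 1.0, env_state * 2.0


# Pattern that hangs for us with impl=warp after many training steps:
# epoch_fn = jax.pmap(training_epoch, axis_name="i", donate_argnums=(0, 1))

# Workaround: drop donate_argnums so the input buffers are not donated.
epoch_fn = jax.pmap(training_epoch, axis_name="i")

num_devices = jax.local_device_count()
training_state = jnp.zeros((num_devices, 4))
env_state = jnp.ones((num_devices, 4))
training_state, env_state = epoch_fn(training_state, env_state)
```

The downside of dropping donate_argnums is somewhat higher device memory usage, since XLA can no longer reuse the donated input buffers for the outputs.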
Steps for reproduction
As mentioned above, I unfortunately don't yet have a good reproduction for this. Our MuJoCo model is proprietary and our training setup is somewhat customized, which is why I haven't been able to reproduce the issue with open-source models or vanilla Brax.
Minimal model for reproduction
No response
Code required for reproduction
No response
Confirmations
- I searched the latest documentation thoroughly before posting.
- I searched previous Issues and Discussions, I am certain this has not been raised before.