
MJX: Warp FFI sometimes deadlocks when donate_argnums is present in pmap #2980

@hartikainen

Description


Intro

Hey,

This bug report is light on details, and I'm aware that the JAX-Warp backend is still very experimental. I still wanted to file the issue to see if anyone has experienced something similar, and hopefully to save others some time.

My setup

mujoco==3.4.0, mujoco-mjx==3.4.0, warp-lang==1.10.1 on linux_x86_64 with 8xA100 GPUs.

What's happening? What did you expect?

We use Brax PPO to train RL policies. We've recently noticed that the jax.pmap of training_epoch in ppo/train.py hangs when donate_argnums=(0, 1) is set. What makes this hang particularly annoying is that it only appears after on the order of 10 million training steps.

The reason I think this is an MJX-Warp FFI issue is that training works fine with impl=jax and only fails with impl=warp. I've also recently observed a similar hang with the MJX-Warp FFI in a completely independent setting: there, I was accessing Warp memory outside its allocated range, which deadlocked silently in a very similar way.

A simple workaround for anyone hitting similar deadlocks/hangs is to remove the donate_argnums argument from the jax.pmap call.
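To make the workaround concrete, here is a minimal sketch. training_epoch below is a hypothetical stand-in for the Brax function (its real signature and state types differ); only the jax.pmap call pattern is the point:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for Brax PPO's training_epoch; the shapes and the
# body are illustrative only, not the actual Brax API.
def training_epoch(training_state, env_state):
    new_training_state = training_state + 1.0
    new_env_state = env_state * 0.5
    return new_training_state, new_env_state

# Original pattern: donating the first two arguments lets XLA reuse their
# device buffers, but with impl=warp this reportedly deadlocks after many
# training steps.
train_fn_donating = jax.pmap(training_epoch, donate_argnums=(0, 1))

# Workaround: drop donate_argnums, at the cost of keeping the input
# buffers alive (extra device memory).
train_fn_safe = jax.pmap(training_epoch)

n = jax.local_device_count()
ts = jnp.zeros((n, 4))
es = jnp.ones((n, 4))
new_ts, new_es = train_fn_safe(ts, es)
```

Note that dropping donation roughly doubles the peak memory held by these arguments, which may matter on memory-constrained setups.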

Steps for reproduction

As mentioned above, I unfortunately don't yet have a good reproduction for this. Our MuJoCo model is proprietary and the training setup is somewhat customized, which is why I haven't been able to reproduce this with open-source models or vanilla Brax.

Minimal model for reproduction

No response

Code required for reproduction

No response

Confirmations

Labels: bug