Warp backend causes NaNs in simulation #248

@aaprasad

Description

I've been using mujoco-playground to train a C. elegans biomechanical model for imitation learning with the MJX backend, and it has worked fine. When I switch to the Warp backend with an otherwise identical configuration, the physics becomes unstable and the state goes to NaN. I'm not sure what could be causing this. These are the library versions I'm working with:

jax                  0.7.2     pypi_0    pypi
jax-cuda12-pjrt      0.7.2     pypi_0    pypi
jax-cuda12-plugin    0.7.2     pypi_0    pypi
jaxlib               0.7.2     pypi_0    pypi
jaxopt               0.8.5     pypi_0    pypi
mujoco               3.4.0     pypi_0    pypi
mujoco-mjx           3.4.0     pypi_0    pypi
mujoco-warp          0.0.1     pypi_0    pypi
playground           0.0.5     pypi_0    pypi
warp-lang            1.10.0    pypi_0    pypi
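
For anyone comparing environments, the same version information can be gathered with the standard library. This is only a sketch: `installed_versions` and `PACKAGES` are illustrative names, and the package list simply mirrors the table above.

```python
from importlib import metadata

# Distribution names taken from the version table in this issue.
PACKAGES = [
    "jax", "jax-cuda12-pjrt", "jax-cuda12-plugin", "jaxlib", "jaxopt",
    "mujoco", "mujoco-mjx", "mujoco-warp", "playground", "warp-lang",
]


def installed_versions(names):
    """Map each distribution name to its version string, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name:<20} {version or 'not installed'}")
```

Running it in both the working and the failing environment makes it easy to diff the two sets of versions.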

I'm on Ubuntu with a 5090 GPU and CUDA 12.8. I've attached a minimal example that reproduces the bug by simply stepping through a basic environment:

warp_debug.zip
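
To narrow down where the values first blow up, it may help to check every state field after each step and report the first one that stops being finite. The following is a generic sketch of that idea and not the contents of the attached zip: `first_nonfinite_step` and `toy_step` are hypothetical names, the state is assumed to be a plain dict of float lists, and the `qvel`/`qpos` field names are borrowed from MuJoCo conventions purely for illustration. With the real environment, the step function and the finiteness check would need to be swapped for the backend's own.

```python
import math


def first_nonfinite_step(step_fn, state, n_steps):
    """Apply step_fn up to n_steps times.

    Returns (step_index, field_name) for the first field containing a NaN or
    inf, or None if every value stays finite for all n_steps.
    """
    for i in range(n_steps):
        state = step_fn(state)
        for name, values in state.items():
            if not all(math.isfinite(v) for v in values):
                return i, name
    return None


def toy_step(state):
    """Toy stand-in for an unstable simulator: velocity grows without bound."""
    qvel = [v * 1e200 for v in state["qvel"]]  # overflows to inf on the 2nd call
    qpos = [p + v for p, v in zip(state["qpos"], qvel)]
    return {"qvel": qvel, "qpos": qpos}


if __name__ == "__main__":
    start = {"qvel": [1.0], "qpos": [0.0]}
    print(first_nonfinite_step(toy_step, start, n_steps=10))
```

Knowing which field diverges first, and at which step, can help distinguish a bad initial state from an instability that builds up over time.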
