undefined symbol: wp_cuda_graph_launch raised when using the Warp MJX backend #2865

@MasonMcGill

Intro

Hi,

I'm a computational neuroscientist at HHMI Janelia. I use MuJoCo to simulate fruit fly biomechanics.

My setup

CPU: Intel Xeon Platinum 8562Y+
GPU: NVIDIA L4
Operating system: Rocky Linux 9.6
NVIDIA driver version: 580.65.06
Python version: 3.11.13
MuJoCo version: 3.3.6
JAX version: 0.7.2
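
For completeness, here is a minimal sketch of how the package versions in an environment like this one could be collected programmatically. It only uses the standard-library importlib.metadata module. The helper name installed_versions is hypothetical, and "warp-lang" is my assumption about which distribution provides warp/bin/warp.so; adjust the list if your environment differs.

```python
from importlib import metadata


def installed_versions(distributions):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumption: "warp-lang" is the distribution that ships warp/bin/warp.so.
    names = ["jax", "jaxlib", "mujoco", "mujoco-mjx", "warp-lang"]
    for name, version in installed_versions(names).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the same environment that runs the reproduction script (for example via uv) shows which Warp build the FFI wrapper is actually loading.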

What's happening? What did you expect?

When using the MJX Warp backend, calling a function compiled with jax.jit for the second time raises a shared library symbol lookup error:

== Calling my compiled function for the first time ==
Module mujoco.mjx.warp.ffi c862578 load on device 'cuda:0' took 5.50 ms  (cached)
Result: [0.        0.        2.9994113]
== Calling it again ==
Traceback (most recent call last):
  File "/groups/turaga/home/mcgillm/.cache/uv/environments-v2/bug-report-083d895a808135fc/lib/python3.11/site-packages/mujoco/mjx/third_party/warp/jax_experimental/ffi.py", line 578, in ffi_callback
    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/groups/turaga/home/mcgillm/.local/share/uv/python/cpython-3.11.13-linux-x86_64-gnu/lib/python3.11/ctypes/__init__.py", line 389, in __getattr__
    func = self.__getitem__(name)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/groups/turaga/home/mcgillm/.local/share/uv/python/cpython-3.11.13-linux-x86_64-gnu/lib/python3.11/ctypes/__init__.py", line 394, in __getitem__
    func = self._FuncPtr((name_or_ordinal, self))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: /groups/turaga/home/mcgillm/.cache/uv/environments-v2/bug-report-083d895a808135fc/lib/python3.11/site-packages/warp/bin/warp.so: undefined symbol: wp_cuda_graph_launch

E0920 12:14:55.520473 2752322 pjrt_stream_executor_client.cc:3314] Execution of replica 0 failed: UNKNOWN: FFI callback error: AttributeError: /groups/turaga/home/mcgillm/.cache/uv/environments-v2/bug-report-083d895a808135fc/lib/python3.11/site-packages/warp/bin/warp.so: undefined symbol: wp_cuda_graph_launch
Traceback (most recent call last):
  File "/groups/turaga/home/mcgillm/fly-body-experiments/scripts/bug-report.py", line 47, in <module>
    main()
  File "/groups/turaga/home/mcgillm/fly-body-experiments/scripts/bug-report.py", line 23, in main
    print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: UNKNOWN: FFI callback error: AttributeError: /groups/turaga/home/mcgillm/.cache/uv/environments-v2/bug-report-083d895a808135fc/lib/python3.11/site-packages/warp/bin/warp.so: undefined symbol: wp_cuda_graph_launch

Running nm -D on warp.so shows that the library exports the symbol cuda_graph_launch (without the wp_ prefix), but not the prefixed wp_cuda_graph_launch that the FFI callback looks up.
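
The same check can be sketched from Python with ctypes, which is also the mechanism that fails in the traceback above. This is only an illustration: the helpers exports_symbol and warp_library_path are hypothetical names, and the warp/bin/warp.so layout is taken from the path in the error message rather than from any documented guarantee.

```python
import ctypes
import importlib.util
from pathlib import Path


def exports_symbol(library, name):
    """Return True if the loaded shared library resolves `name`."""
    try:
        getattr(library, name)  # ctypes raises AttributeError on a failed lookup
    except AttributeError:
        return False
    return True


def warp_library_path():
    """Guess the location of warp.so without importing the warp package."""
    spec = importlib.util.find_spec("warp")
    if spec is None or spec.origin is None:
        return None
    # Assumption: the binary lives in <package>/bin/warp.so, as in the traceback.
    return Path(spec.origin).parent / "bin" / "warp.so"


if __name__ == "__main__":
    path = warp_library_path()
    if path is None or not path.exists():
        print("warp.so not found in this environment")
    else:
        try:
            library = ctypes.CDLL(str(path))
        except OSError as error:
            print(f"could not load {path}: {error}")
        else:
            for symbol in ("wp_cuda_graph_launch", "cuda_graph_launch"):
                print(f"{symbol}: {exports_symbol(library, symbol)}")
```

If only the unprefixed name resolves, the installed warp.so and the FFI wrapper bundled with MJX disagree about symbol naming, which matches the nm -D observation.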

It's entirely possible that I'm using the new backend incorrectly and that it isn't intended to be used with jax.jit at all. Without compilation, though, simulation is very slow (~0.5 seconds per step for a simple scene), so I'm not sure what else to try.

Steps for reproduction

  1. Install uv if you don't have it on your system already.
  2. Save the script below as bug-report.py.
  3. Run uv run bug-report.py.

Code required for reproduction

#!/usr/bin/env python3
# /// script
# requires-python = "==3.11.*"
# dependencies = [
#   "jax[cuda12]==0.7.2",
#   "mujoco==3.3.6",
#   "mujoco-mjx[warp]==3.3.6",
# ]
# ///
import jax
from mujoco import MjModel, mjx  # type: ignore

def main() -> None:
    host_model = MjModel.from_xml_string(model_spec_string())
    mjx_model = mjx.put_model(host_model, impl="warp")
    initial_sim_state = mjx.make_data(host_model, impl="warp")

    print("== Calling my compiled function for the first time ==")
    print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))

    print("== Calling it again ==")
    print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))

@jax.jit
def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array:
    sim_state = initial_sim_state
    for _ in range(5):
        sim_state = mjx.step(model, sim_state)
    return sim_state.qpos[:3]

def model_spec_string() -> str:
    return """<mujoco>
        <worldbody>
            <geom type="plane" size="10 10 0.1"/>
            <body pos="0 0 3">
                <geom type="sphere" size="0.2" rgba="0 1 0 1"/>
                <joint type="free"/>
            </body>
        </worldbody>
    </mujoco>"""

if __name__ == "__main__":
    main()

Labels: MJX (Using JAX to run on GPU), bug (Something isn't working)
