From 90df4066702a231f84620b862c9e4576da20f1a6 Mon Sep 17 00:00:00 2001 From: Harshdhall01 Date: Sat, 20 Sep 2025 23:48:25 +0530 Subject: [PATCH 1/2] Fix: Removed Warp backend to resolve symbol error --- bug-report.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 bug-report.py diff --git a/bug-report.py b/bug-report.py new file mode 100644 index 0000000000..6887bf1fe7 --- /dev/null +++ b/bug-report.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = "==3.11.*" +# dependencies = [ +# "jax==0.7.2", +# "mujoco==3.2.0", +# "mujoco-mjx==3.2.0", +# ] +# /// +import jax +from mujoco import MjModel, mjx + +def main() -> None: + host_model = MjModel.from_xml_string(model_spec_string()) + mjx_model = mjx.put_model(host_model) # Removed impl="warp" + initial_sim_state = mjx.make_data(host_model) # Removed impl="warp" + + print("== Calling my compiled function for the first time ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + + print("== Calling it again ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + +@jax.jit +def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array: + sim_state = initial_sim_state + for _ in range(5): + sim_state = mjx.step(model, sim_state) + return sim_state.qpos[:3] + +def model_spec_string() -> str: + return """ + + + + + + + + """ + +if __name__ == "__main__": + main() \ No newline at end of file From bb76189d35e29930ee8db5b66d5d6e612ac1d022 Mon Sep 17 00:00:00 2001 From: Harshdhall01 Date: Sun, 21 Sep 2025 09:08:22 +0000 Subject: [PATCH 2/2] Fix undefined symbol wp_cuda_graph_launch in Warp backend The Warp library exports cuda_graph_launch but MJX FFI code was looking for wp_cuda_graph_launch (with wp_ prefix). This caused AttributeError on second call to JAX-compiled functions. Fixes: undefined symbol error in Warp backend with JAX compilation --- .../third_party/warp/jax_experimental/ffi.py | 2 +- test-fix.py | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 test-fix.py diff --git a/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py index 05165b491d..db3edff887 100644 --- a/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py +++ b/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py @@ -575,7 +575,7 @@ def ffi_callback(self, call_frame): raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}") graph.graph_exec = g - if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream): + if not wp.context.runtime.core.cuda_graph_launch(graph.graph_exec, cuda_stream): raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}") # early out diff --git a/test-fix.py b/test-fix.py new file mode 100644 index 0000000000..21b133a817 --- /dev/null +++ b/test-fix.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +import jax +from mujoco import MjModel, mjx + +def main() -> None: + host_model = MjModel.from_xml_string(model_spec_string()) + mjx_model = mjx.put_model(host_model, impl="warp") + initial_sim_state = mjx.make_data(host_model, impl="warp") + + print("== Calling my compiled function for the first time ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + + print("== Calling it again ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + +@jax.jit +def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array: + sim_state = initial_sim_state + for _ in range(5): + sim_state = mjx.step(model, sim_state) + return sim_state.qpos[:3] + +def model_spec_string() -> str: + return """ + + + + + + + + """ + +if __name__ == "__main__": + main()