diff --git a/bug-report.py b/bug-report.py new file mode 100644 index 0000000000..6887bf1fe7 --- /dev/null +++ b/bug-report.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = "==3.11.*" +# dependencies = [ +# "jax==0.7.2", +# "mujoco==3.2.0", +# "mujoco-mjx==3.2.0", +# ] +# /// +import jax +from mujoco import MjModel, mjx + +def main() -> None: + host_model = MjModel.from_xml_string(model_spec_string()) + mjx_model = mjx.put_model(host_model) # Removed impl="warp" + initial_sim_state = mjx.make_data(host_model) # Removed impl="warp" + + print("== Calling my compiled function for the first time ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + + print("== Calling it again ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + +@jax.jit +def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array: + sim_state = initial_sim_state + for _ in range(5): + sim_state = mjx.step(model, sim_state) + return sim_state.qpos[:3] + +def model_spec_string() -> str: + return """ + + + + + + + + """ + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py index 05165b491d..db3edff887 100644 --- a/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py +++ b/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py @@ -575,7 +575,7 @@ def ffi_callback(self, call_frame): raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}") graph.graph_exec = g - if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream): + if not wp.context.runtime.core.cuda_graph_launch(graph.graph_exec, cuda_stream): raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}") # early out diff --git a/test-fix.py b/test-fix.py new file mode 100644 index 0000000000..21b133a817 --- /dev/null +++ b/test-fix.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +import jax +from mujoco import MjModel, mjx + +def main() -> None: + host_model = MjModel.from_xml_string(model_spec_string()) + mjx_model = mjx.put_model(host_model, impl="warp") + initial_sim_state = mjx.make_data(host_model, impl="warp") + + print("== Calling my compiled function for the first time ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + + print("== Calling it again ==") + print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state)) + +@jax.jit +def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array: + sim_state = initial_sim_state + for _ in range(5): + sim_state = mjx.step(model, sim_state) + return sim_state.qpos[:3] + +def model_spec_string() -> str: + return """ + + + + + + + + """ + +if __name__ == "__main__": + main()