diff --git a/bug-report.py b/bug-report.py
new file mode 100644
index 0000000000..6887bf1fe7
--- /dev/null
+++ b/bug-report.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = "==3.11.*"
+# dependencies = [
+# "jax==0.7.2",
+# "mujoco==3.2.0",
+# "mujoco-mjx==3.2.0",
+# ]
+# ///
+import jax
+from mujoco import MjModel, mjx
+
+def main() -> None:
+ host_model = MjModel.from_xml_string(model_spec_string())
+ mjx_model = mjx.put_model(host_model) # Removed impl="warp"
+ initial_sim_state = mjx.make_data(host_model) # Removed impl="warp"
+
+ print("== Calling my compiled function for the first time ==")
+ print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
+
+ print("== Calling it again ==")
+ print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
+
+@jax.jit
+def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array:
+ sim_state = initial_sim_state
+ for _ in range(5):
+ sim_state = mjx.step(model, sim_state)
+ return sim_state.qpos[:3]
+
+def model_spec_string() -> str:
+ return """
+
+
+
+
+
+
+
+ """
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py
index 05165b491d..db3edff887 100644
--- a/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py
+++ b/mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py
@@ -575,7 +575,7 @@ def ffi_callback(self, call_frame):
raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
graph.graph_exec = g
- if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
+ if not wp.context.runtime.core.cuda_graph_launch(graph.graph_exec, cuda_stream):
raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
# early out
diff --git a/test-fix.py b/test-fix.py
new file mode 100644
index 0000000000..21b133a817
--- /dev/null
+++ b/test-fix.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+import jax
+from mujoco import MjModel, mjx
+
+def main() -> None:
+ host_model = MjModel.from_xml_string(model_spec_string())
+ mjx_model = mjx.put_model(host_model, impl="warp")
+ initial_sim_state = mjx.make_data(host_model, impl="warp")
+
+ print("== Calling my compiled function for the first time ==")
+ print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
+
+ print("== Calling it again ==")
+ print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
+
+@jax.jit
+def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array:
+ sim_state = initial_sim_state
+ for _ in range(5):
+ sim_state = mjx.step(model, sim_state)
+ return sim_state.qpos[:3]
+
+def model_spec_string() -> str:
+ return """
+
+
+
+
+
+
+
+ """
+
+if __name__ == "__main__":
+ main()