Skip to content

Commit db66919

Browse files
committed
Fix undefined symbol wp_cuda_graph_launch in Warp backend
The Warp library exports cuda_graph_launch but MJX FFI code was looking for wp_cuda_graph_launch (with wp_ prefix). This caused AttributeError on second call to JAX-compiled functions. Fixes: undefined symbol error in Warp backend with JAX compilation
1 parent 90df406 commit db66919

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

mjx/mujoco/mjx/third_party/warp/jax_experimental/ffi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def ffi_callback(self, call_frame):
575575
raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
576576
graph.graph_exec = g
577577

578-
if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
578+
if not wp.context.runtime.core.cuda_graph_launch(graph.graph_exec, cuda_stream):
579579
raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
580580

581581
# early out

test-fix.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/usr/bin/env python3
2+
import jax
3+
from mujoco import MjModel, mjx
4+
5+
def main() -> None:
6+
host_model = MjModel.from_xml_string(model_spec_string())
7+
mjx_model = mjx.put_model(host_model, impl="warp")
8+
initial_sim_state = mjx.make_data(host_model, impl="warp")
9+
10+
print("== Calling my compiled function for the first time ==")
11+
print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
12+
13+
print("== Calling it again ==")
14+
print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
15+
16+
@jax.jit
17+
def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array:
18+
sim_state = initial_sim_state
19+
for _ in range(5):
20+
sim_state = mjx.step(model, sim_state)
21+
return sim_state.qpos[:3]
22+
23+
def model_spec_string() -> str:
24+
return """<mujoco>
25+
<worldbody>
26+
<geom type="plane" size="10 10 0.1"/>
27+
<body pos="0 0 3">
28+
<geom type="sphere" size="0.2" rgba="0 1 0 1"/>
29+
<joint type="free"/>
30+
</body>
31+
</worldbody>
32+
</mujoco>"""
33+
34+
if __name__ == "__main__":
35+
main()

0 commit comments

Comments
 (0)