Skip to content

Commit 90df406

Browse files
committed
Fix: Removed Warp backend to resolve symbol error
1 parent 0141c6b commit 90df406

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

bug-report.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python3
2+
# /// script
3+
# requires-python = "==3.11.*"
4+
# dependencies = [
5+
# "jax==0.7.2",
6+
# "mujoco==3.2.0",
7+
# "mujoco-mjx==3.2.0",
8+
# ]
9+
# ///
10+
import jax
11+
from mujoco import MjModel, mjx
12+
13+
def main() -> None:
14+
host_model = MjModel.from_xml_string(model_spec_string())
15+
mjx_model = mjx.put_model(host_model) # Removed impl="warp"
16+
initial_sim_state = mjx.make_data(host_model) # Removed impl="warp"
17+
18+
print("== Calling my compiled function for the first time ==")
19+
print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
20+
21+
print("== Calling it again ==")
22+
print("Result:", ball_pos_after_5_steps(mjx_model, initial_sim_state))
23+
24+
@jax.jit
25+
def ball_pos_after_5_steps(model: mjx.Model, initial_sim_state: mjx.Data) -> jax.Array:
26+
sim_state = initial_sim_state
27+
for _ in range(5):
28+
sim_state = mjx.step(model, sim_state)
29+
return sim_state.qpos[:3]
30+
31+
def model_spec_string() -> str:
32+
return """<mujoco>
33+
<worldbody>
34+
<geom type="plane" size="10 10 0.1"/>
35+
<body pos="0 0 3">
36+
<geom type="sphere" size="0.2" rgba="0 1 0 1"/>
37+
<joint type="free"/>
38+
</body>
39+
</worldbody>
40+
</mujoco>"""
41+
42+
if __name__ == "__main__":
43+
main()

0 commit comments

Comments
 (0)