File tree Expand file tree Collapse file tree 1 file changed +43
-0
lines changed Expand file tree Collapse file tree 1 file changed +43
-0
lines changed Original file line number Diff line number Diff line change
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = "==3.11.*"
4
+ # dependencies = [
5
+ # "jax==0.7.2",
6
+ # "mujoco==3.2.0",
7
+ # "mujoco-mjx==3.2.0",
8
+ # ]
9
+ # ///
10
+ import jax
11
+ from mujoco import MjModel , mjx
12
+
13
+ def main () -> None :
14
+ host_model = MjModel .from_xml_string (model_spec_string ())
15
+ mjx_model = mjx .put_model (host_model ) # Removed impl="warp"
16
+ initial_sim_state = mjx .make_data (host_model ) # Removed impl="warp"
17
+
18
+ print ("== Calling my compiled function for the first time ==" )
19
+ print ("Result:" , ball_pos_after_5_steps (mjx_model , initial_sim_state ))
20
+
21
+ print ("== Calling it again ==" )
22
+ print ("Result:" , ball_pos_after_5_steps (mjx_model , initial_sim_state ))
23
+
24
+ @jax .jit
25
+ def ball_pos_after_5_steps (model : mjx .Model , initial_sim_state : mjx .Data ) -> jax .Array :
26
+ sim_state = initial_sim_state
27
+ for _ in range (5 ):
28
+ sim_state = mjx .step (model , sim_state )
29
+ return sim_state .qpos [:3 ]
30
+
31
+ def model_spec_string () -> str :
32
+ return """<mujoco>
33
+ <worldbody>
34
+ <geom type="plane" size="10 10 0.1"/>
35
+ <body pos="0 0 3">
36
+ <geom type="sphere" size="0.2" rgba="0 1 0 1"/>
37
+ <joint type="free"/>
38
+ </body>
39
+ </worldbody>
40
+ </mujoco>"""
41
+
42
+ if __name__ == "__main__" :
43
+ main ()
You can’t perform that action at this time.
0 commit comments