Having trouble getting jacobians out of a simple Mujoco MJX simulation. #2650
Intro

Hi! My name is Tristan, and I am a PhD student at Texas Tech University. I use MuJoCo for my research on information theory and optimal control theory.

My setup

MuJoCo version: '3.3.2'

My question

I need to extract the Jacobians of a dynamical system defined by the MuJoCo simulator.

Dynamics (MuJoCo simulation): $x_{t+1} = f(x_t, u_t)$

Jacobians: $A_t = \frac{\partial f}{\partial x}(x_t, u_t)$, $B_t = \frac{\partial f}{\partial u}(x_t, u_t)$

I am finding this difficult to do. Forward-mode autodiff (jax.jacfwd) works, but reverse-mode autodiff (jax.jacrev) does not, and I need to understand why reverse mode fails. When running my code I get the following error:

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):

Minimal model and/or code that explain my question

Model: minimal XML
<mujoco>
<asset>
<texture type="skybox" builtin="gradient" rgb1="0.6 0.7 0.9" rgb2="0.3 0.3 0.4" width="512" height="3072"/>
<texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5" reflectance="0.2"/>
</asset>
<worldbody>
<light name="global_light" directional="true" diffuse="1 1 1" specular="0.3 0.3 0.3" castshadow="false"/>
<light name="fill_light_0" pos="-5 0 10" dir="0 0 -1" directional="false" diffuse="1 1 1" specular="0.2 0.2 0.2"/>
<light name="fill_light_1" pos="0 0 10" dir="0 0 -1" directional="false" diffuse="1 1 1" specular="0.2 0.2 0.2"/>
<light name="fill_light_2" pos="5 0 10" dir="0 0 -1" directional="false" diffuse="1 1 1" specular="0.2 0.2 0.2"/>
<geom name="floor" pos="0 0 0" size="0 0 0.05" type="plane" material="groundplane"/>
<body name="sphere_0" pos="-0.5 0.0 0.15">
<!-- position joints -->
<joint name="sphere_0_x_joint" type="slide" axis="1 0 0"/>
<joint name="sphere_0_z_joint" type="slide" axis="0 0 1"/>
<!-- angle joints -->
<!-- <joint name="sphere_0_i_joint" type="ball" axis="1 0 0"/> -->
<geom size=".15" mass="1" type="sphere"/>
<site name="site_sphere_0" pos="0 0 0"/>
</body>
<body name="sphere_1" pos="0.5 0.0 0.15">
<joint name="sphere_1_x_joint" type="slide" axis="1 0 0"/>
<joint name="sphere_1_z_joint" type="slide" axis="0 0 1"/>
<!-- angle joints -->
<!-- <joint name="sphere_1_i_joint" type="ball" axis="1 0 0"/> -->
<geom size=".15" mass="1" type="sphere"/>
<site name="site_sphere_1" pos="0 0 0"/>
</body>
<body name="sphere_2" pos="-0.5 0.0 1.15">
<joint name="sphere_2_x_joint" type="slide" axis="1 0 0"/>
<joint name="sphere_2_z_joint" type="slide" axis="0 0 1"/>
<!-- angle joints -->
<!-- <joint name="sphere_2_i_joint" type="ball" axis="1 0 0"/> -->
<geom size=".15" mass="1" type="sphere"/>
<site name="site_sphere_2" pos="0 0 0"/>
</body>
<body name="sphere_3" pos="0.5 0.0 1.15">
<joint name="sphere_3_x_joint" type="slide" axis="1 0 0"/>
<joint name="sphere_3_z_joint" type="slide" axis="0 0 1"/>
<!-- angle joints -->
<!-- <joint name="sphere_3_i_joint" type="ball" axis="1 0 0"/> -->
<geom size=".15" mass="1" type="sphere"/>
<site name="site_sphere_3" pos="0 0 0"/>
</body>
</worldbody>
<tendon>
<spatial name="spring_0_1" stiffness="20" damping="1" springlength="1.0">
<site site="site_sphere_0"/>
<site site="site_sphere_1"/>
</spatial>
<spatial name="spring_0_2" stiffness="30" damping="1" springlength="1.0">
<site site="site_sphere_0"/>
<site site="site_sphere_2"/>
</spatial>
<spatial name="spring_2_3" stiffness="30" damping="1" springlength="1.0">
<site site="site_sphere_2"/>
<site site="site_sphere_3"/>
</spatial>
<spatial name="spring_3_1" stiffness="30" damping="1" springlength="1.0">
<site site="site_sphere_3"/>
<site site="site_sphere_1"/>
</spatial>
<spatial name="spring_3_0" stiffness="30" damping="1" springlength="1.4142">
<site site="site_sphere_3"/>
<site site="site_sphere_0"/>
</spatial>
<spatial name="spring_2_1" stiffness="30" damping="1" springlength="1.4142">
<site site="site_sphere_2"/>
<site site="site_sphere_1"/>
</spatial>
</tendon>
<actuator>
<motor joint="sphere_0_x_joint" ctrllimited="true" ctrlrange="-10 10"/>
</actuator>
</mujoco>
Code:

import mujoco
from mujoco import mjx
from mujoco.mjx import Model, Data
import jax
from jax import numpy as jnp
from jax import Array

# show full tracebacks, including JAX-internal frames
jax.config.update("jax_traceback_filtering", "off")

if __name__ == '__main__':
    model = mujoco.MjModel.from_xml_path('xml_files/custom/blob.xml')

    ## converting to mjx
    mjx_model = mjx.put_model(model)
    mjx_data = mjx.make_data(mjx_model)

    ## define state and control input
    nq, nv, nu = mjx_model.nq, mjx_model.nv, mjx_model.nu
    xt = jnp.zeros(nq + nv)
    ut = jnp.zeros(nu)

    ## jit the jax step function
    mjx_step = jax.jit(mjx.step)

    def f(state: Array, action: Array, m: Model, d: Data) -> Array:
        # one simulation step x_{t+1} = f(x_t, u_t), with state = [qpos, qvel]
        qpos = state[:nq]
        qvel = state[nq:]  # velocities start after the nq position entries
        d = d.replace(qpos=qpos, qvel=qvel, ctrl=action)
        d = mjx_step(m, d)
        return jnp.concatenate([d.qpos, d.qvel])

    # J = jax.jacfwd(f, argnums=0)  ## works
    J = jax.jacrev(f, argnums=0)  ## fails
    A = J(xt, ut, mjx_model, mjx_data)
    print(A)
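The exception text in the traceback excerpt is not preserved, so the following is an assumption rather than a diagnosis: a common reason `jax.jacrev` fails while `jax.jacfwd` succeeds is a `jax.lax.while_loop` with a data-dependent stopping condition somewhere in the computation (iterative constraint solvers are a typical place for one). JAX can push forward-mode tangents through `while_loop`, but reverse mode has to transpose the loop, and `while_loop` has no transpose rule. The toy sketch below uses hypothetical functions `solve_while` and `solve_scan` (unrelated to MJX internals) to reproduce the pattern, and shows that the same iteration run for a fixed number of steps with `jax.lax.scan` is reverse-differentiable.

```python
import jax
import jax.numpy as jnp


def solve_while(x):
    """Hypothetical iterative 'solver' with a data-dependent stopping rule.

    Iterates y <- 0.5 * y + x (fixed point y* = 2 * x) until successive
    iterates stop changing, capped at 200 iterations.
    """

    def cond(carry):
        i, y, y_prev = carry
        still_moving = jnp.max(jnp.abs(y - y_prev)) > 1e-6
        return jnp.logical_and(i < 200, still_moving)

    def body(carry):
        i, y, _ = carry
        return i + 1, 0.5 * y + x, y

    init = (jnp.array(0), jnp.zeros_like(x), jnp.full_like(x, jnp.inf))
    _, y, _ = jax.lax.while_loop(cond, body, init)
    return y


def solve_scan(x, num_iters=50):
    """Same iteration, but with a fixed, static number of steps via lax.scan."""

    def body(y, _):
        return 0.5 * y + x, None

    y, _ = jax.lax.scan(body, jnp.zeros_like(x), xs=None, length=num_iters)
    return y


x = jnp.array([1.0, 2.0, 3.0])

# Forward mode (batched JVPs) can be pushed through a while_loop.
J_fwd = jax.jacfwd(solve_while)(x)
print(J_fwd)

# Reverse mode must transpose the loop; JAX raises a ValueError for while_loop.
try:
    jax.jacrev(solve_while)(x)
    while_rev_failed = False
except ValueError as err:
    while_rev_failed = True
    print("jacrev through while_loop failed:", err)

# A fixed-length scan can be transposed, so reverse mode works here.
J_rev = jax.jacrev(solve_scan)(x)
print(J_rev)
```

If a fixed iteration count is acceptable, converting the loop to `scan` is the usual way to regain reverse mode; otherwise forward mode is the option that works without changing the loop.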
Replies: 1 comment 1 reply
Did you try setting …
I have done that and it works. But it seems the best solution for me is to just use jacfwd.