Description
Hey,
I am currently evaluating MuJoCo Playground for model-predictive control and want to rely on automatic differentiation. During one optimisation I get NaN gradients on the cart pole. As far as I can tell, this does not happen at special states such as collisions or bifurcations, but at a fairly ordinary point in state space. I would expect MuJoCo Playground/MJX to handle this gracefully.
Below is a standalone example that I am running on a Mac M3 (not a GPU!) and which produces NaNs immediately. All it does is take the gradient of the cart pole cost with respect to a control trajectory, for a fixed initial state. No MLP policy involved, no training, just feeding in a control sequence and differentiating with respect to it.
I did some digging, which means I replaced the scans with explicit for loops, but I stopped once I had traced the problem to env.step. Sadly, JAX does not give good information on NaNs when they occur inside a scan/fori_loop. Setting JAX_ENABLE_X64=True fixes this, but that is not really satisfying.
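For reference, here is a minimal sketch of how the `jax_debug_nans` flag behaves once the computation is unrolled into plain Python. It does not touch the cart pole at all. The `hazard` function is a made-up toy whose forward value is finite but whose gradient is NaN, which is only an assumption about the kind of failure that might hide inside env.step.

```python
import jax
import jax.numpy as jnp


def hazard(x):
    # Toy function (hypothetical, not from MJX): the forward value is finite
    # at x == 0, but the untaken sqrt branch still contributes 0 * inf = NaN
    # to the gradient.
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)


x = jnp.float32(0.0)
g = jax.grad(hazard)(x)
print("value:", hazard(x), "grad:", g)

# With jax_debug_nans enabled, JAX is expected to raise as soon as an
# operation produces a NaN, instead of silently propagating it.
jax.config.update("jax_debug_nans", True)
try:
    jax.grad(hazard)(x)
except FloatingPointError as err:
    print("debug_nans caught:", err)
finally:
    jax.config.update("jax_debug_nans", False)
```

The same flag can be set for a whole run through the `JAX_DEBUG_NANS=True` environment variable. Whether the resulting error message is specific enough for a full env.step is something I have not verified.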
How can I debug this effectively? I would like to know where the NaNs originate, but that is hard to find out. Maybe there are some knobs to tune that make it go away, such as solver dts or similar.
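One generic way to localise the problem in time might be to differentiate truncated versions of the rollout and report the first horizon at which the gradient stops being finite. The sketch below shows the idea on a toy system; `toy_step`, `toy_cost` and `first_bad_horizon` are hypothetical names, and the dynamics are deliberately trivial rather than anything from MJX.

```python
import jax
import jax.numpy as jnp


def toy_step(x, u):
    # Hypothetical stand-in for env.step: a one-dimensional integrator.
    return x + u


def toy_cost(plan, x0):
    # Unrolled rollout with a per-step cost whose gradient is NaN at x <= 0.
    x = x0
    total = 0.0
    for i in range(plan.shape[0]):
        x = toy_step(x, plan[i])
        total = total + jnp.where(x > 0.0, jnp.sqrt(x), 0.0)
    return total


def first_bad_horizon(cost_fn, plan, x0):
    """Return the shortest horizon whose gradient is non-finite, else None."""
    for h in range(1, plan.shape[0] + 1):
        g = jax.grad(cost_fn)(plan[:h], x0)
        if not bool(jnp.all(jnp.isfinite(g))):
            return h
    return None


toy_plan = jnp.array([-0.5, -0.5, -0.5], dtype=jnp.float32)
x0 = jnp.float32(1.0)
print(first_bad_horizon(toy_cost, toy_plan, x0))
```

In this toy case the state reaches exactly zero after the second control, so the search stops at horizon 2. For the cart pole, the same loop over `plan[:h]` would cost one gradient evaluation per horizon, which should still be affordable for 20 controls.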
The script is below. It can be run directly once it has been marked as executable, e.g. with chmod +x. It also saves the rollout as a GIF.
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "playground", "moviepy"
# ]
# ///
import sys
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree as jt
import mujoco_playground as mp
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from mujoco import mjx
env = mp.registry.load("CartpoleBalance")
mjx_model = env.mjx_model
qpos = jnp.array([0.06638916, 0.02480028], dtype="float")
qvel = jnp.array([0.01101728, -0.001906830], dtype="float")
# fmt: off
plan = jnp.array(
    [
        0.4406856, -0.15623903, 0.15121995, -0.6513512, 0.38101465,
        -0.8696436, 0.6175662, 0.6062752, 0.43547752, 0.24218927,
        0.89787376, 0.42185387, 0.788631, 0.42122164, 0.8994053,
        0.9807481, -0.53927225, 0.00305146, -0.32759377, 0.629019,
    ],
    dtype="float",
)[:, jnp.newaxis]
# fmt: on
mjx_data = mjx.make_data(mjx_model)
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
mjx_data = mjx.forward(mjx_model, mjx_data)
state0 = env.reset(jr.PRNGKey(32))
state0 = state0.replace(data=mjx_data)
jit_env_step = jax.jit(env.step)
def tree_stack(trees):
    """Takes a list of trees and stacks every corresponding leaf.

    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).

    Useful for turning a list of objects into something you can feed to a
    vmapped function.
    """
    leaves_list = []
    treedef_list = []
    for tree in trees:
        leaves, treedef = jt.flatten(tree)
        leaves_list.append(leaves)
        treedef_list.append(treedef)
    grouped_leaves = zip(*leaves_list)
    result_leaves = [jnp.stack(k) for k in grouped_leaves]
    return treedef_list[0].unflatten(result_leaves)
def cost_fn(plan, initial_state):
    """Average negative reward of a rollout driven by env.step."""
    states = [initial_state]
    for i in range(plan.shape[0]):
        # Emulate inner steps: apply each control for 4 env steps.
        for _ in range(4):
            states.append(jit_env_step(states[-1], plan[i]))
    # Stack all visited states into one pytree (returned as aux output).
    state_arr = tree_stack(states)
    # Cost is the negative reward, sampled at every 4th state.
    costs = -jnp.stack([s.reward for s in states[::4]])
    avg_cost = costs.mean()
    return avg_cost, state_arr
grad_fn = jax.grad(cost_fn, has_aux=True)
grad, states = grad_fn(plan, state0)
print(grad)
state_list = [
    jt.map(lambda s: s[i], states) for i in range(states.reward.shape[0])
]
frames = env.render(state_list)
clip = ImageSequenceClip(list(frames), fps=10)
clip.write_gif("rollout.gif", fps=10, logger=None)
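To narrow down which quantity goes bad first, one option might be to scan a pytree (for example a gradient taken with respect to the state instead of the plan) and list the leaves that contain NaN or inf. The helper below is a hypothetical sketch that uses only generic JAX tree utilities. The `toy_state` dictionary is invented for illustration and does not mirror the real env state layout.

```python
import jax
import jax.numpy as jnp


def nonfinite_leaves(tree):
    """Return key paths of floating-point leaves that contain NaN or inf."""
    bad = []
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves:
        leaf = jnp.asarray(leaf)
        is_float = jnp.issubdtype(leaf.dtype, jnp.floating)
        if is_float and not bool(jnp.all(jnp.isfinite(leaf))):
            bad.append(jax.tree_util.keystr(path))
    return bad


# Invented example pytree; field names are placeholders only.
toy_state = {
    "qpos": jnp.array([0.1, 0.2]),
    "qvel": jnp.array([jnp.nan, 0.0]),
    "reward": jnp.array(1.0),
}
bad_paths = nonfinite_leaves(toy_state)
print(bad_paths)
```

Applied after each unrolled step, a check like this would at least show which field is affected first, even if it cannot say which primitive produced the NaN.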