NaN gradients of cost wrt actions on cart pole balance (full standalone example with specific state and controls) #239

@bayerj

Hey,

I am currently evaluating MuJoCo Playground for model-predictive control and want to rely on automatic differentiation. During one optimisation, I get NaN gradients on the cart pole. As far as I can tell, this does not happen in a special state such as a collision or a bifurcation, but at a fairly ordinary point in state space. My expectation would be that MuJoCo Playground/MJX handles such states gracefully.

Below is a standalone example that I am running on a Mac M3 (not a GPU!) and which produces NaNs immediately. All it does is take the gradient of the cart pole cost with respect to a control trajectory, for a fixed initial state. No MLP policy involved, no training, just feeding in a control sequence and differentiating with respect to it.

I did some digging: I replaced the scans with explicit for loops, but I stopped after tracing the NaNs to env.step. Sadly, JAX does not give good information about NaNs when they occur inside a scan/fori_loop. Setting JAX_ENABLE_X64=True makes the NaNs go away, but that is not really a satisfying fix.
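For reference, the environment variable above has an in-script equivalent, and JAX also has a NaN-checking flag. The snippet below is a minimal sketch of setting both at the top of a script; it only shows the configuration and is not a fix for the issue. The `qpos` values are copied from the script further down.

```python
import jax

# Set this at startup, before any arrays are created.
# It has the same effect as exporting JAX_ENABLE_X64=True.
jax.config.update("jax_enable_x64", True)

# Ask JAX to raise an error when a computation produces a NaN,
# instead of silently propagating it.
jax.config.update("jax_debug_nans", True)

import jax.numpy as jnp

# With x64 enabled, dtype="float" (as used in the script) resolves to float64.
qpos = jnp.array([0.06638916, 0.02480028], dtype="float")
print(qpos.dtype)
```

With x64 disabled, the same `dtype="float"` arrays are stored in single precision, which is the setting in which the NaNs appear.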

How can I debug this effectively? I would like to know where the NaNs originate, but that is hard to find out. Maybe there are some knobs to tune that make them go away, like solver time steps (dt) or similar.

The script is below. Once it has been marked as executable (e.g. with chmod +x), it can be run directly. It also saves the rollout as a GIF.

#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
#   "playground", "moviepy"
# ]
# ///

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree as jt
import mujoco_playground as mp
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from mujoco import mjx

env = mp.registry.load("CartpoleBalance")
mjx_model = env.mjx_model

qpos = jnp.array([0.06638916, 0.02480028], dtype="float")
qvel = jnp.array([0.01101728, -0.001906830], dtype="float")
# fmt: off
plan = jnp.array(
    [
        0.4406856, -0.15623903, 0.15121995, -0.6513512, 0.38101465,
        -0.8696436, 0.6175662, 0.6062752, 0.43547752, 0.24218927,
        0.89787376, 0.42185387, 0.788631, 0.42122164, 0.8994053,
        0.9807481, -0.53927225, 0.00305146, -0.32759377, 0.629019,
    ],
    dtype="float",
)[:, jnp.newaxis]
# fmt: on

mjx_data = mjx.make_data(mjx_model)
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
mjx_data = mjx.forward(mjx_model, mjx_data)
state0 = env.reset(jr.PRNGKey(32))
state0 = state0.replace(data=mjx_data)

jit_env_step = jax.jit(env.step)


def tree_stack(trees):
    """Takes a list of trees and stacks every corresponding leaf.
    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).
    Useful for turning a list of objects into something you can feed to a
    vmapped function.
    """
    leaves_list = []
    treedef_list = []
    for tree in trees:
        leaves, treedef = jt.flatten(tree)
        leaves_list.append(leaves)
        treedef_list.append(treedef)

    grouped_leaves = zip(*leaves_list)
    result_leaves = [jnp.stack(k) for k in grouped_leaves]
    return treedef_list[0].unflatten(result_leaves)


def cost_fn(plan, initial_state):
    """Average cost of rolling out `plan` from `initial_state` via env.step."""
    states = [initial_state]
    for i in range(plan.shape[0]):
        # Emulate inner steps: apply each control for 4 consecutive env steps.
        for j in range(4):
            states.append(jit_env_step(states[-1], plan[i]))
    state_arr = tree_stack(states)
    # Cost is the negative reward, sampled at every 4th state.
    costs = -jnp.stack([s.reward for s in states[::4]])
    avg_cost = costs.mean()
    return avg_cost, state_arr


grad_fn = jax.grad(cost_fn, has_aux=True)
grad, states = grad_fn(plan, state0)
print(grad)

state_list = [
    jt.map(lambda s: s[i], states) for i in range(states.reward.shape[0])
]
frames = env.render(state_list)
clip = ImageSequenceClip(list(frames), fps=10)
clip.write_gif("rollout.gif", fps=10, logger=None)
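Since the rollout is already unrolled into Python loops, one way to narrow things down might be to differentiate progressively longer prefixes of the rollout and report the first horizon at which the gradient stops being finite. The sketch below shows the idea on toy dynamics. `toy_step`, `rollout_cost` and `first_bad_horizon` are hypothetical names, and the toy step is only a stand-in for env.step: it is built so that the forward pass stays finite while the backward pass hits 0 * inf, and I am not claiming this is what happens inside MJX.

```python
import jax
import jax.numpy as jnp


def toy_step(x, u):
    # Hypothetical stand-in for env.step. The forward value is always finite,
    # but the derivative of sqrt at 0 is inf, so the backward pass computes
    # inf * 0 = NaN whenever x is exactly 1.
    dist = jnp.sqrt(jnp.sum((x - 1.0) ** 2))
    return x + u * dist


def rollout_cost(plan, x0, horizon):
    # Roll out only the first `horizon` controls of the plan.
    x = x0
    for t in range(horizon):
        x = toy_step(x, plan[t])
    return jnp.sum(x**2)


def first_bad_horizon(plan, x0):
    # Return the shortest horizon whose gradient wrt the plan is non-finite,
    # or None if all horizons give finite gradients.
    for h in range(1, plan.shape[0] + 1):
        g = jax.grad(lambda p: rollout_cost(p, x0, h))(plan)
        if not bool(jnp.all(jnp.isfinite(g))):
            return h
    return None


x0 = jnp.array([0.5])
plan = jnp.array([1.0, 0.3, 0.2])
print(first_bad_horizon(plan, x0))
```

In this toy case the first step moves the state to exactly 1, so the NaN only shows up once a second step differentiates through that state. Applied to the script above, the same loop would call `cost_fn` on `plan[:h]` instead of `rollout_cost`, which at least tells you which env.step call to look into.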
