[MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients #2259

@EGalahad

Description

The feature, motivation and pitch

Problem

The solver's jax.lax.while_loop implementation prevents gradient computation through the environment step during gradient-based trajectory optimization. The failure occurs in the solver implementation whenever opt.iterations > 1.

Error encountered with a jax.jit-compiled grad function:

```
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.
```

The current workaround of setting opt.iterations = 1 leads to potentially inaccurate simulation and gradients.
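
For reference, here is a minimal sketch of the failure mode, using a toy contraction as a hypothetical stand-in for the solver iteration (the solve function, the 0.5 * val update, and the tolerance are illustrative, not the actual MJX code):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the MJX solver loop: a while_loop whose
# trip count depends on the data, like the solver's convergence check.
def solve(x):
    def cond(val):
        return jnp.abs(val) > 1e-6  # data-dependent stopping rule

    def body(val):
        return 0.5 * val  # one "solver iteration" (placeholder update)

    return jax.lax.while_loop(cond, body, x)

# Raises the ValueError quoted above: reverse-mode differentiation
# is not supported through a loop with a dynamic trip count.
jax.grad(solve)(1.0)
```

The root cause is that reverse-mode AD must record every intermediate iterate for the backward pass, which JAX cannot do when the number of iterations is only known at run time.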

Proposed Solution

Add an option to run the solver for a fixed iteration count (e.g., 4), implemented with either lax.scan or lax.fori_loop with static bounds, which would be compatible with reverse-mode differentiation. A sketch of this approach follows.
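
Below is a minimal sketch of that option under the same toy stand-in for a solver iteration (solve_fixed and num_iterations are illustrative names, not existing MJX API):

```python
import jax

# Sketch of the proposed fix: run a static number of solver iterations
# with lax.scan, which supports reverse-mode differentiation because it
# records the per-iteration residuals needed for the backward pass.
# solve_fixed and the 0.5 * val update are placeholders, not MJX code.
def solve_fixed(x, num_iterations=4):
    def body(val, _):
        return 0.5 * val, None  # one "solver iteration"

    val, _ = jax.lax.scan(body, x, xs=None, length=num_iterations)
    return val

# Reverse-mode gradients now work, at the cost of always running
# num_iterations steps even after the solver would have converged.
print(jax.grad(solve_fixed)(1.0))  # d/dx (0.5**4 * x) = 0.0625
```

lax.fori_loop with static Python-int bounds would also work, since JAX lowers it to a scan in that case; the trade-off either way is running the full num_iterations even after convergence, plus O(num_iterations) memory for the residuals saved for the backward pass.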

Alternatives

No response

Additional context

No response

Metadata

Labels

MJX (Using JAX to run on GPU) · enhancement (New feature or request) · good first issue (Good for newcomers)
