Labels: MJX (Using JAX to run on GPU), enhancement (New feature or request), good first issue (Good for newcomers)
Description
The feature, motivation and pitch
Problem
The solver's jax.lax.while_loop implementation prevents gradient computation through the environment step during gradient-based trajectory optimization. This occurs in the solver implementation when iterations > 1.
Error encountered with a jax.jit-compiled grad function:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.
The current workaround of setting opt.iterations = 1 leads to potentially inaccurate simulation and gradients.
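A minimal sketch of the failure mode (toy code, not the actual MJX solver; the solve function and the 0.5 contraction are illustrative assumptions): reverse-mode differentiation through a lax.while_loop whose trip count depends on the data raises the same error.

```python
import jax
import jax.numpy as jnp

def solve(x):
    # Iterate until a tolerance is met, mimicking a solver whose
    # trip count depends on the data.
    def cond(carry):
        val, _ = carry
        return jnp.abs(val) > 1e-6

    def body(carry):
        val, n = carry
        return val * 0.5, n + 1

    val, _ = jax.lax.while_loop(cond, body, (x, 0))
    return val

grad_fn = jax.jit(jax.grad(solve))
# grad_fn(1.0)  # raises: ValueError: Reverse-mode differentiation does not
#               # work for lax.while_loop or lax.fori_loop with dynamic
#               # start/stop values.
```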
Proposed Solution
Add an option to set a fixed iteration count (e.g., 4) that would be compatible with reverse-mode differentiation, using either lax.scan or lax.fori_loop with static bounds.
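A rough sketch of the proposed fixed-iteration variant, again with toy code rather than the MJX solver (the solve_fixed function and niter parameter are hypothetical names): with a static trip count, lax.scan supports reverse-mode differentiation, so gradients flow through every solver iteration.

```python
import jax

def solve_fixed(x, niter=4):
    # Run exactly `niter` solver-style updates; the static trip count
    # makes the loop compatible with reverse-mode differentiation.
    def body(val, _):
        return val * 0.5, None

    val, _ = jax.lax.scan(body, x, xs=None, length=niter)
    return val

# Mark niter as static so the scan length stays a Python int under jit.
grad_fn = jax.jit(jax.grad(solve_fixed), static_argnums=1)
print(grad_fn(1.0, 4))  # 0.0625 == 0.5 ** 4; gradients flow through all 4 steps
```

lax.fori_loop with static start/stop bounds should work equally well here, since JAX can lower such a loop to a scan when reverse-mode differentiation is requested.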
Alternatives
No response
Additional context
No response