diff --git a/mjx/mujoco/mjx/_src/solver.py b/mjx/mujoco/mjx/_src/solver.py index f57dbd1cce..34dbfec683 100644 --- a/mjx/mujoco/mjx/_src/solver.py +++ b/mjx/mujoco/mjx/_src/solver.py @@ -598,6 +598,19 @@ def body(ctx: Context) -> Context: ctx = Context.create(m, d) if m.opt.iterations == 1: ctx = body(ctx) + + elif hasattr(m.opt, 'fixed_iterations') and m.opt.fixed_iterations is not None: + # Use fori_loop with static bounds for gradient compatibility + def fixed_body_fn(i, ctx): + return body(ctx) + + # This loop runs exactly fixed_iterations times + ctx = jax.lax.fori_loop( + 0, m.opt.fixed_iterations, + fixed_body_fn, + ctx + ) + else: ctx = jax.lax.while_loop(cond, body, ctx) diff --git a/mjx/mujoco/mjx/_src/solver_test.py b/mjx/mujoco/mjx/_src/solver_test.py index 0b1b3b5a07..3c89ab5e56 100644 --- a/mjx/mujoco/mjx/_src/solver_test.py +++ b/mjx/mujoco/mjx/_src/solver_test.py @@ -25,6 +25,7 @@ import numpy as np + # tolerance for difference between MuJoCo and MJX solver calculations, # mostly due to float precision _TOLERANCE = 5e-3 diff --git a/mjx/mujoco/mjx/_src/types.py b/mjx/mujoco/mjx/_src/types.py index 8d9ade1849..762c30969f 100644 --- a/mjx/mujoco/mjx/_src/types.py +++ b/mjx/mujoco/mjx/_src/types.py @@ -18,6 +18,7 @@ from typing import Tuple, Union import warnings +from typing import Optional import jax import mujoco from mujoco.mjx._src.dataclasses import PyTreeNode # pylint: disable=g-importing-member @@ -483,6 +484,7 @@ class OptionJAX(PyTreeNode): disableactuator: int sdf_initpoints: int has_fluid_params: bool + fixed_iterations: Optional[int] = None class OptionC(PyTreeNode):