diff --git a/src/mjlab/envs/manager_based_rl_env.py b/src/mjlab/envs/manager_based_rl_env.py index 133c74ae9..198f525aa 100644 --- a/src/mjlab/envs/manager_based_rl_env.py +++ b/src/mjlab/envs/manager_based_rl_env.py @@ -157,10 +157,13 @@ def __init__( cfg: ManagerBasedRlEnvCfg, device: str, render_mode: str | None = None, + *, + num_steps_per_env: int | None = None, **kwargs, ) -> None: # Initialize base environment state. self.cfg = cfg + self._num_steps_per_env = num_steps_per_env if self.cfg.seed is not None: self.cfg.seed = self.seed(self.cfg.seed) self._sim_step_counter = 0 @@ -299,7 +302,9 @@ def load_managers(self) -> None: ) print_info(f"[INFO] {self.reward_manager}") if len(self.cfg.curriculum) > 0: - self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self) + self.curriculum_manager = CurriculumManager( + self.cfg.curriculum, self, num_steps_per_env=self._num_steps_per_env + ) else: self.curriculum_manager = NullCurriculumManager() print_info(f"[INFO] {self.curriculum_manager}") diff --git a/src/mjlab/managers/curriculum_manager.py b/src/mjlab/managers/curriculum_manager.py index d34b25aa6..236686d70 100644 --- a/src/mjlab/managers/curriculum_manager.py +++ b/src/mjlab/managers/curriculum_manager.py @@ -26,6 +26,32 @@ class CurriculumTermCfg(ManagerTermBaseCfg): pass +def resolve_curriculum_iterations( + curriculum: dict[str, CurriculumTermCfg], + num_steps_per_env: int, +) -> None: + """Convert ``"iteration"`` keys to ``"step"`` keys in curriculum stages. + + This allows curriculum configs to express thresholds in training + iterations (more intuitive) rather than raw environment steps. + The conversion is ``step = iteration * num_steps_per_env``. + + Modifies *curriculum* in-place. Raises :class:`ValueError` if a stage + dict contains both ``"iteration"`` and ``"step"`` keys. + """ + for term_cfg in curriculum.values(): + for value in term_cfg.params.values(): + if not isinstance(value, list): + continue + for item in value: + if not isinstance(item, dict): + continue + if "iteration" in item and "step" in item: + raise ValueError(f"Curriculum stage has both 'iteration' and 'step': {item}") + if "iteration" in item: + item["step"] = item.pop("iteration") * num_steps_per_env + + class CurriculumManager(ManagerBase): """Manages curriculum learning for the environment. @@ -36,12 +62,23 @@ class CurriculumManager(ManagerBase): _env: ManagerBasedRlEnv - def __init__(self, cfg: dict[str, CurriculumTermCfg], env: ManagerBasedRlEnv): + def __init__( + self, + cfg: dict[str, CurriculumTermCfg], + env: ManagerBasedRlEnv, + *, + num_steps_per_env: int | None = None, + ): self._term_names: list[str] = list() self._term_cfgs: list[CurriculumTermCfg] = list() self._class_term_cfgs: list[CurriculumTermCfg] = list() + # Work on a deep copy so modifications are local to this manager. self.cfg = deepcopy(cfg) + + if num_steps_per_env is not None: + resolve_curriculum_iterations(self.cfg, num_steps_per_env) + super().__init__(env) self._curriculum_state = dict() diff --git a/src/mjlab/scripts/play.py b/src/mjlab/scripts/play.py index 3fd8602d9..606027952 100644 --- a/src/mjlab/scripts/play.py +++ b/src/mjlab/scripts/play.py @@ -154,7 +154,12 @@ def run_play(task_id: str, cfg: PlayConfig): print( "[WARN] Video recording with dummy agents is disabled (no checkpoint/log_dir)." ) - env = ManagerBasedRlEnv(cfg=env_cfg, device=device, render_mode=render_mode) + env = ManagerBasedRlEnv( + cfg=env_cfg, + device=device, + render_mode=render_mode, + num_steps_per_env=agent_cfg.num_steps_per_env, + ) if TRAINED_MODE and cfg.video: print("[INFO] Recording videos during play") diff --git a/src/mjlab/scripts/train.py b/src/mjlab/scripts/train.py index 22079dfbf..726458acd 100644 --- a/src/mjlab/scripts/train.py +++ b/src/mjlab/scripts/train.py @@ -106,7 +106,10 @@ def run_train(task_id: str, cfg: TrainConfig, log_dir: Path) -> None: print(f"[INFO] Logging experiment in directory: {log_dir}") env = ManagerBasedRlEnv( - cfg=cfg.env, device=device, render_mode="rgb_array" if cfg.video else None + cfg=cfg.env, + device=device, + render_mode="rgb_array" if cfg.video else None, + num_steps_per_env=cfg.agent.num_steps_per_env, ) log_root_path = log_dir.parent # Go up from specific run dir to experiment dir. diff --git a/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py b/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py index 533805172..bc743f781 100644 --- a/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py +++ b/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py @@ -202,9 +202,9 @@ def make_lift_cube_env_cfg() -> ManagerBasedRlEnvCfg: params={ "reward_name": "joint_vel_hinge", "weight_stages": [ - {"step": 0, "weight": -0.01}, - {"step": 500 * 24, "weight": -0.1}, - {"step": 1000 * 24, "weight": -1.0}, + {"iteration": 0, "weight": -0.01}, + {"iteration": 500, "weight": -0.1}, + {"iteration": 1000, "weight": -1.0}, ], }, ), diff --git a/src/mjlab/tasks/tracking/scripts/evaluate.py b/src/mjlab/tasks/tracking/scripts/evaluate.py index 800af4e0e..3e333c6eb 100644 --- a/src/mjlab/tasks/tracking/scripts/evaluate.py +++ b/src/mjlab/tasks/tracking/scripts/evaluate.py @@ -71,7 +71,9 @@ def run_evaluate(task_id: str, cfg: EvaluateConfig) -> dict[str, float]: env_cfg.events.pop("push_robot", None) env_cfg.scene.num_envs = cfg.num_envs - env = ManagerBasedRlEnv(cfg=env_cfg, device=device) + env = ManagerBasedRlEnv( + cfg=env_cfg, device=device, num_steps_per_env=agent_cfg.num_steps_per_env + ) env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions) log_root_path = (Path("logs") / "rsl_rl" / agent_cfg.experiment_name).resolve() diff --git a/src/mjlab/tasks/velocity/velocity_env_cfg.py b/src/mjlab/tasks/velocity/velocity_env_cfg.py index 110ce0423..f4a5082e6 100644 --- a/src/mjlab/tasks/velocity/velocity_env_cfg.py +++ b/src/mjlab/tasks/velocity/velocity_env_cfg.py @@ -363,9 +363,9 @@ def make_velocity_env_cfg() -> ManagerBasedRlEnvCfg: params={ "command_name": "twist", "velocity_stages": [ - {"step": 0, "lin_vel_x": (-1.0, 1.0), "ang_vel_z": (-0.5, 0.5)}, - {"step": 5000 * 24, "lin_vel_x": (-1.5, 2.0), "ang_vel_z": (-0.7, 0.7)}, - {"step": 10000 * 24, "lin_vel_x": (-2.0, 3.0)}, + {"iteration": 0, "lin_vel_x": (-1.0, 1.0), "ang_vel_z": (-0.5, 0.5)}, + {"iteration": 5000, "lin_vel_x": (-1.5, 2.0), "ang_vel_z": (-0.7, 0.7)}, + {"iteration": 10000, "lin_vel_x": (-2.0, 3.0)}, ], }, ), diff --git a/tests/test_curriculum_manager.py b/tests/test_curriculum_manager.py new file mode 100644 index 000000000..9628ee402 --- /dev/null +++ b/tests/test_curriculum_manager.py @@ -0,0 +1,146 @@ +"""Tests for resolve_curriculum_iterations() and CurriculumManager.""" + +import copy +from unittest.mock import Mock + +import pytest + +from mjlab.managers.curriculum_manager import ( + CurriculumManager, + CurriculumTermCfg, + resolve_curriculum_iterations, +) + + +def _dummy_func(*args, **kwargs): + pass + + +def _make_curriculum(stages_key, stages): + """Build a minimal curriculum dict with one term.""" + return { + "term": CurriculumTermCfg( + func=_dummy_func, + params={"some_name": "twist", stages_key: stages}, + ) + } + + +def test_converts_iteration_to_step(): + stages = [ + {"iteration": 0, "lin_vel_x": (-1.0, 1.0)}, + {"iteration": 5000, "lin_vel_x": (-2.0, 2.0)}, + ] + curriculum = _make_curriculum("velocity_stages", stages) + resolve_curriculum_iterations(curriculum, num_steps_per_env=24) + + resolved = curriculum["term"].params["velocity_stages"] + assert resolved[0] == {"step": 0, "lin_vel_x": (-1.0, 1.0)} + assert resolved[1] == {"step": 120000, "lin_vel_x": (-2.0, 2.0)} + assert all("iteration" not in stage for stage in resolved) + + +def test_leaves_step_based_stages_untouched(): + stages = [ + {"step": 0, "weight": -0.01}, + {"step": 12000, "weight": -1.0}, + ] + curriculum = _make_curriculum("weight_stages", stages) + original = copy.deepcopy(curriculum) + resolve_curriculum_iterations(curriculum, num_steps_per_env=24) + + assert ( + curriculum["term"].params["weight_stages"] + == original["term"].params["weight_stages"] + ) + + +def test_leaves_non_stage_params_untouched(): + curriculum = _make_curriculum( + "velocity_stages", + [{"iteration": 100, "lin_vel_x": (-1.0, 1.0)}], + ) + resolve_curriculum_iterations(curriculum, num_steps_per_env=24) + assert curriculum["term"].params["some_name"] == "twist" + + +def test_raises_on_ambiguous_stage(): + stages = [{"iteration": 100, "step": 2400, "weight": -1.0}] + curriculum = _make_curriculum("weight_stages", stages) + with pytest.raises(ValueError, match="both 'iteration' and 'step'"): + resolve_curriculum_iterations(curriculum, num_steps_per_env=24) + + +def test_velocity_env_cfg_stages_resolve(): + """End-to-end: velocity config's command_vel stages resolve correctly.""" + from mjlab.tasks.velocity.velocity_env_cfg import make_velocity_env_cfg + + env_cfg = make_velocity_env_cfg() + resolve_curriculum_iterations(env_cfg.curriculum, num_steps_per_env=24) + + stages = env_cfg.curriculum["command_vel"].params["velocity_stages"] + assert stages[0]["step"] == 0 + assert stages[1]["step"] == 5000 * 24 + assert stages[2]["step"] == 10000 * 24 + assert all("iteration" not in stage for stage in stages) + + +def test_lift_cube_env_cfg_stages_resolve(): + """End-to-end: lift cube config's weight stages resolve correctly.""" + from mjlab.tasks.manipulation.lift_cube_env_cfg import ( + make_lift_cube_env_cfg, + ) + + env_cfg = make_lift_cube_env_cfg() + resolve_curriculum_iterations(env_cfg.curriculum, num_steps_per_env=24) + + stages = env_cfg.curriculum["joint_vel_hinge_weight"].params["weight_stages"] + assert stages[0]["step"] == 0 + assert stages[1]["step"] == 500 * 24 + assert stages[2]["step"] == 1000 * 24 + assert all("iteration" not in stage for stage in stages) + + +def _make_mock_env(): + env = Mock() + env.num_envs = 1 + env.device = "cpu" + env.scene = None + return env + + +def test_curriculum_manager_resolves_iterations_via_num_steps_per_env(): + stages = [ + {"iteration": 0, "lin_vel_x": (-1.0, 1.0)}, + {"iteration": 5000, "lin_vel_x": (-2.0, 2.0)}, + ] + curriculum = _make_curriculum("velocity_stages", stages) + mgr = CurriculumManager(curriculum, _make_mock_env(), num_steps_per_env=24) + + resolved = mgr.cfg["term"].params["velocity_stages"] + assert resolved[0] == {"step": 0, "lin_vel_x": (-1.0, 1.0)} + assert resolved[1] == {"step": 120000, "lin_vel_x": (-2.0, 2.0)} + assert all("iteration" not in stage for stage in resolved) + + +def test_curriculum_manager_leaves_iterations_when_num_steps_not_given(): + stages = [ + {"iteration": 0, "lin_vel_x": (-1.0, 1.0)}, + {"iteration": 5000, "lin_vel_x": (-2.0, 2.0)}, + ] + curriculum = _make_curriculum("velocity_stages", stages) + mgr = CurriculumManager(curriculum, _make_mock_env()) + + resolved = mgr.cfg["term"].params["velocity_stages"] + assert all("iteration" in stage for stage in resolved) + assert all("step" not in stage for stage in resolved) + + +def test_curriculum_manager_does_not_mutate_input_cfg(): + stages = [{"iteration": 100, "lin_vel_x": (-1.0, 1.0)}] + curriculum = _make_curriculum("velocity_stages", stages) + original_stages = copy.deepcopy(stages) + + CurriculumManager(curriculum, _make_mock_env(), num_steps_per_env=24) + + assert curriculum["term"].params["velocity_stages"] == original_stages