Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/mjlab/envs/manager_based_rl_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,13 @@ def __init__(
cfg: ManagerBasedRlEnvCfg,
device: str,
render_mode: str | None = None,
*,
num_steps_per_env: int | None = None,
**kwargs,
) -> None:
# Initialize base environment state.
self.cfg = cfg
self._num_steps_per_env = num_steps_per_env
if self.cfg.seed is not None:
self.cfg.seed = self.seed(self.cfg.seed)
self._sim_step_counter = 0
Expand Down Expand Up @@ -299,7 +302,9 @@ def load_managers(self) -> None:
)
print_info(f"[INFO] {self.reward_manager}")
if len(self.cfg.curriculum) > 0:
self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self)
self.curriculum_manager = CurriculumManager(
self.cfg.curriculum, self, num_steps_per_env=self._num_steps_per_env
)
else:
self.curriculum_manager = NullCurriculumManager()
print_info(f"[INFO] {self.curriculum_manager}")
Expand Down
39 changes: 38 additions & 1 deletion src/mjlab/managers/curriculum_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,32 @@ class CurriculumTermCfg(ManagerTermBaseCfg):
pass


def resolve_curriculum_iterations(
curriculum: dict[str, CurriculumTermCfg],
num_steps_per_env: int,
) -> None:
"""Convert ``"iteration"`` keys to ``"step"`` keys in curriculum stages.

This allows curriculum configs to express thresholds in training
iterations (more intuitive) rather than raw environment steps.
The conversion is ``step = iteration * num_steps_per_env``.

Modifies *curriculum* in-place. Raises :class:`ValueError` if a stage
dict contains both ``"iteration"`` and ``"step"`` keys.
"""
for term_cfg in curriculum.values():
for value in term_cfg.params.values():
if not isinstance(value, list):
continue
for item in value:
if not isinstance(item, dict):
continue
if "iteration" in item and "step" in item:
raise ValueError(f"Curriculum stage has both 'iteration' and 'step': {item}")
if "iteration" in item:
item["step"] = item.pop("iteration") * num_steps_per_env


class CurriculumManager(ManagerBase):
"""Manages curriculum learning for the environment.

Expand All @@ -36,12 +62,23 @@ class CurriculumManager(ManagerBase):

_env: ManagerBasedRlEnv

def __init__(self, cfg: dict[str, CurriculumTermCfg], env: ManagerBasedRlEnv):
def __init__(
self,
cfg: dict[str, CurriculumTermCfg],
env: ManagerBasedRlEnv,
*,
num_steps_per_env: int | None = None,
):
self._term_names: list[str] = list()
self._term_cfgs: list[CurriculumTermCfg] = list()
self._class_term_cfgs: list[CurriculumTermCfg] = list()

# Work on a deep copy so modifications are local to this manager.
self.cfg = deepcopy(cfg)

if num_steps_per_env is not None:
resolve_curriculum_iterations(self.cfg, num_steps_per_env)

super().__init__(env)

self._curriculum_state = dict()
Expand Down
7 changes: 6 additions & 1 deletion src/mjlab/scripts/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,12 @@ def run_play(task_id: str, cfg: PlayConfig):
print(
"[WARN] Video recording with dummy agents is disabled (no checkpoint/log_dir)."
)
env = ManagerBasedRlEnv(cfg=env_cfg, device=device, render_mode=render_mode)
env = ManagerBasedRlEnv(
cfg=env_cfg,
device=device,
render_mode=render_mode,
num_steps_per_env=agent_cfg.num_steps_per_env,
)

if TRAINED_MODE and cfg.video:
print("[INFO] Recording videos during play")
Expand Down
5 changes: 4 additions & 1 deletion src/mjlab/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def run_train(task_id: str, cfg: TrainConfig, log_dir: Path) -> None:
print(f"[INFO] Logging experiment in directory: {log_dir}")

env = ManagerBasedRlEnv(
cfg=cfg.env, device=device, render_mode="rgb_array" if cfg.video else None
cfg=cfg.env,
device=device,
render_mode="rgb_array" if cfg.video else None,
num_steps_per_env=cfg.agent.num_steps_per_env,
)

log_root_path = log_dir.parent # Go up from specific run dir to experiment dir.
Expand Down
6 changes: 3 additions & 3 deletions src/mjlab/tasks/manipulation/lift_cube_env_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def make_lift_cube_env_cfg() -> ManagerBasedRlEnvCfg:
params={
"reward_name": "joint_vel_hinge",
"weight_stages": [
{"step": 0, "weight": -0.01},
{"step": 500 * 24, "weight": -0.1},
{"step": 1000 * 24, "weight": -1.0},
{"iteration": 0, "weight": -0.01},
{"iteration": 500, "weight": -0.1},
{"iteration": 1000, "weight": -1.0},
],
},
),
Expand Down
4 changes: 3 additions & 1 deletion src/mjlab/tasks/tracking/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def run_evaluate(task_id: str, cfg: EvaluateConfig) -> dict[str, float]:
env_cfg.events.pop("push_robot", None)
env_cfg.scene.num_envs = cfg.num_envs

env = ManagerBasedRlEnv(cfg=env_cfg, device=device)
env = ManagerBasedRlEnv(
cfg=env_cfg, device=device, num_steps_per_env=agent_cfg.num_steps_per_env
)
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

log_root_path = (Path("logs") / "rsl_rl" / agent_cfg.experiment_name).resolve()
Expand Down
6 changes: 3 additions & 3 deletions src/mjlab/tasks/velocity/velocity_env_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ def make_velocity_env_cfg() -> ManagerBasedRlEnvCfg:
params={
"command_name": "twist",
"velocity_stages": [
{"step": 0, "lin_vel_x": (-1.0, 1.0), "ang_vel_z": (-0.5, 0.5)},
{"step": 5000 * 24, "lin_vel_x": (-1.5, 2.0), "ang_vel_z": (-0.7, 0.7)},
{"step": 10000 * 24, "lin_vel_x": (-2.0, 3.0)},
{"iteration": 0, "lin_vel_x": (-1.0, 1.0), "ang_vel_z": (-0.5, 0.5)},
{"iteration": 5000, "lin_vel_x": (-1.5, 2.0), "ang_vel_z": (-0.7, 0.7)},
{"iteration": 10000, "lin_vel_x": (-2.0, 3.0)},
],
},
),
Expand Down
146 changes: 146 additions & 0 deletions tests/test_curriculum_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Tests for resolve_curriculum_iterations() and CurriculumManager."""

import copy
from unittest.mock import Mock

import pytest

from mjlab.managers.curriculum_manager import (
CurriculumManager,
CurriculumTermCfg,
resolve_curriculum_iterations,
)


def _dummy_func(*args, **kwargs):
pass


def _make_curriculum(stages_key, stages):
"""Build a minimal curriculum dict with one term."""
return {
"term": CurriculumTermCfg(
func=_dummy_func,
params={"some_name": "twist", stages_key: stages},
)
}


def test_converts_iteration_to_step():
stages = [
{"iteration": 0, "lin_vel_x": (-1.0, 1.0)},
{"iteration": 5000, "lin_vel_x": (-2.0, 2.0)},
]
curriculum = _make_curriculum("velocity_stages", stages)
resolve_curriculum_iterations(curriculum, num_steps_per_env=24)

resolved = curriculum["term"].params["velocity_stages"]
assert resolved[0] == {"step": 0, "lin_vel_x": (-1.0, 1.0)}
assert resolved[1] == {"step": 120000, "lin_vel_x": (-2.0, 2.0)}
assert all("iteration" not in stage for stage in resolved)


def test_leaves_step_based_stages_untouched():
stages = [
{"step": 0, "weight": -0.01},
{"step": 12000, "weight": -1.0},
]
curriculum = _make_curriculum("weight_stages", stages)
original = copy.deepcopy(curriculum)
resolve_curriculum_iterations(curriculum, num_steps_per_env=24)

assert (
curriculum["term"].params["weight_stages"]
== original["term"].params["weight_stages"]
)


def test_leaves_non_stage_params_untouched():
curriculum = _make_curriculum(
"velocity_stages",
[{"iteration": 100, "lin_vel_x": (-1.0, 1.0)}],
)
resolve_curriculum_iterations(curriculum, num_steps_per_env=24)
assert curriculum["term"].params["some_name"] == "twist"


def test_raises_on_ambiguous_stage():
stages = [{"iteration": 100, "step": 2400, "weight": -1.0}]
curriculum = _make_curriculum("weight_stages", stages)
with pytest.raises(ValueError, match="both 'iteration' and 'step'"):
resolve_curriculum_iterations(curriculum, num_steps_per_env=24)


def test_velocity_env_cfg_stages_resolve():
"""End-to-end: velocity config's command_vel stages resolve correctly."""
from mjlab.tasks.velocity.velocity_env_cfg import make_velocity_env_cfg

env_cfg = make_velocity_env_cfg()
resolve_curriculum_iterations(env_cfg.curriculum, num_steps_per_env=24)

stages = env_cfg.curriculum["command_vel"].params["velocity_stages"]
assert stages[0]["step"] == 0
assert stages[1]["step"] == 5000 * 24
assert stages[2]["step"] == 10000 * 24
assert all("iteration" not in stage for stage in stages)


def test_lift_cube_env_cfg_stages_resolve():
"""End-to-end: lift cube config's weight stages resolve correctly."""
from mjlab.tasks.manipulation.lift_cube_env_cfg import (
make_lift_cube_env_cfg,
)

env_cfg = make_lift_cube_env_cfg()
resolve_curriculum_iterations(env_cfg.curriculum, num_steps_per_env=24)

stages = env_cfg.curriculum["joint_vel_hinge_weight"].params["weight_stages"]
assert stages[0]["step"] == 0
assert stages[1]["step"] == 500 * 24
assert stages[2]["step"] == 1000 * 24
assert all("iteration" not in stage for stage in stages)


def _make_mock_env():
env = Mock()
env.num_envs = 1
env.device = "cpu"
env.scene = None
return env


def test_curriculum_manager_resolves_iterations_via_num_steps_per_env():
stages = [
{"iteration": 0, "lin_vel_x": (-1.0, 1.0)},
{"iteration": 5000, "lin_vel_x": (-2.0, 2.0)},
]
curriculum = _make_curriculum("velocity_stages", stages)
mgr = CurriculumManager(curriculum, _make_mock_env(), num_steps_per_env=24)

resolved = mgr.cfg["term"].params["velocity_stages"]
assert resolved[0] == {"step": 0, "lin_vel_x": (-1.0, 1.0)}
assert resolved[1] == {"step": 120000, "lin_vel_x": (-2.0, 2.0)}
assert all("iteration" not in stage for stage in resolved)


def test_curriculum_manager_leaves_iterations_when_num_steps_not_given():
stages = [
{"iteration": 0, "lin_vel_x": (-1.0, 1.0)},
{"iteration": 5000, "lin_vel_x": (-2.0, 2.0)},
]
curriculum = _make_curriculum("velocity_stages", stages)
mgr = CurriculumManager(curriculum, _make_mock_env())

resolved = mgr.cfg["term"].params["velocity_stages"]
assert all("iteration" in stage for stage in resolved)
assert all("step" not in stage for stage in resolved)


def test_curriculum_manager_does_not_mutate_input_cfg():
stages = [{"iteration": 100, "lin_vel_x": (-1.0, 1.0)}]
curriculum = _make_curriculum("velocity_stages", stages)
original_stages = copy.deepcopy(stages)

CurriculumManager(curriculum, _make_mock_env(), num_steps_per_env=24)

assert curriculum["term"].params["velocity_stages"] == original_stages