diff --git a/.gitignore b/.gitignore
index e2f0df855..649f62137 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ __pycache__
 MUJOCO_LOG.TXT
 mujoco_menagerie
 checkpoints/
+logs
diff --git a/learning/train_rsl_rl.py b/learning/train_rsl_rl.py
index 96b7dcccf..71228eef8 100644
--- a/learning/train_rsl_rl.py
+++ b/learning/train_rsl_rl.py
@@ -167,6 +167,12 @@ def render_callback(_, state):
   # Build RSL-RL config
   train_cfg = get_rl_config(_ENV_NAME.value)
 
+  obs_size = raw_env.observation_size
+  if isinstance(obs_size, dict):
+    train_cfg.obs_groups = {"policy": ["state"], "critic": ["privileged_state"]}
+  else:
+    train_cfg.obs_groups = {"policy": ["state"], "critic": ["state"]}
+
   # Overwrite default config with flags
   train_cfg.seed = _SEED.value
   train_cfg.run_name = exp_name
diff --git a/mujoco_playground/_src/wrapper_torch.py b/mujoco_playground/_src/wrapper_torch.py
index 3d70bff99..566a7d605 100644
--- a/mujoco_playground/_src/wrapper_torch.py
+++ b/mujoco_playground/_src/wrapper_torch.py
@@ -17,6 +17,7 @@
 from collections import deque
 import functools
 import os
+from typing import Any
 
 import jax
 import numpy as np
@@ -31,6 +32,10 @@
   torch = None
 
 from mujoco_playground._src import wrapper
 
+try:
+  from tensordict import TensorDict
+except ImportError:
+  TensorDict = None
 
 def _jax_to_torch(tensor):
@@ -158,8 +163,10 @@ def step(self, action):
     if self.asymmetric_obs:
       obs = _jax_to_torch(self.env_state.obs["state"])
       critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
+      obs = {"state": obs, "privileged_state": critic_obs}
     else:
       obs = _jax_to_torch(self.env_state.obs)
+      obs = {"state": obs}
     reward = _jax_to_torch(self.env_state.reward)
     done = _jax_to_torch(self.env_state.done)
     info = self.env_state.info
@@ -187,6 +194,7 @@
         if k not in info_ret["log"]:
           info_ret["log"][k] = _jax_to_torch(v).float().mean().item()
 
+    obs = TensorDict(obs, batch_size=[self.num_envs])
     return obs, reward, done, info_ret
 
   def reset(self):
@@ -195,23 +203,15 @@
     self.env_state = self.reset_fn(self.key_reset)
     if self.asymmetric_obs:
       obs = _jax_to_torch(self.env_state.obs["state"])
-      # critic_obs = jax_to_torch(self.env_state.obs["privileged_state"])
+      critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
+      obs = {"state": obs, "privileged_state": critic_obs}
     else:
       obs = _jax_to_torch(self.env_state.obs)
-    return obs
-
-  def reset_with_critic_obs(self):
-    self.env_state = self.reset_fn(self.key_reset)
-    obs = _jax_to_torch(self.env_state.obs["state"])
-    critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
-    return obs, critic_obs
+      obs = {"state": obs}
+    return TensorDict(obs, batch_size=[self.num_envs])
 
   def get_observations(self):
-    if self.asymmetric_obs:
-      obs, critic_obs = self.reset_with_critic_obs()
-      return obs, {"observations": {"critic": critic_obs}}
-    else:
-      return self.reset(), {"observations": {}}
+    return self.reset()
 
   def render(self, mode="human"):  # pylint: disable=unused-argument
     if self.render_callback is not None:
diff --git a/pyproject.toml b/pyproject.toml
index 1365ddf8a..a693e088e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,6 @@ dependencies = [
     "orbax-checkpoint>=0.11.22",
     "tqdm",
     "warp-lang>=1.9.0.dev",
-    "wandb",
 ]
 
 keywords = ["mjx", "mujoco", "sim2real", "reinforcement learning"]
@@ -75,9 +74,14 @@ dev = [
     "pylint",
     "pytest-xdist",
 ]
+learning = [
+    "rsl-rl-lib>=3.0.0",
+    "wandb",
+]
 all = [
     "playground[dev]",
     "playground[notebooks]",
+    "playground[learning]",
 ]
 
 [tool.hatch.metadata]
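
Reviewer note: rsl-rl 3.x selects actor and critic inputs by name from a dict-style observation via `obs_groups`, which is why the wrapper now returns a `TensorDict` keyed by "state" / "privileged_state" instead of separate tensors. Below is a minimal self-contained sketch of that grouping contract; the observation widths, `num_envs`, and the `torch.cat` concatenation are illustrative assumptions, not code from this patch.

import torch
from tensordict import TensorDict

num_envs = 4  # illustrative batch size
obs = TensorDict(
    {
        "state": torch.zeros(num_envs, 48),             # actor obs; width assumed
        "privileged_state": torch.zeros(num_envs, 91),  # critic-only obs; width assumed
    },
    batch_size=[num_envs],
)

# Mirrors the asymmetric obs_groups set in train_rsl_rl.py above.
obs_groups = {"policy": ["state"], "critic": ["privileged_state"]}
actor_in = torch.cat([obs[k] for k in obs_groups["policy"]], dim=-1)   # shape (4, 48)
critic_in = torch.cat([obs[k] for k in obs_groups["critic"]], dim=-1)  # shape (4, 91)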