diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5b0531b5c..3c2e69808 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@
 ### Bug fixes
 - Policies saved during during optimization with distributed Optuna load on new systems (@jkterry)
 - Fixed script for recording video that was not up to date with the enjoy script
+- Added CEM support
 
 ### Documentation
 
diff --git a/hyperparams/ars.yml b/hyperparams/ars.yml
index e58d4fa3c..2457df5aa 100644
--- a/hyperparams/ars.yml
+++ b/hyperparams/ars.yml
@@ -18,12 +18,14 @@ Pendulum-v1: &pendulum-params
   policy_kwargs: "dict(net_arch=[16])"
   zero_policy: False
 
-# TO BE Tuned
+# Almost Tuned
 LunarLander-v2:
   <<: *pendulum-params
-  n_delta: 6
-  n_top: 1
-  n_timesteps: !!float 2e6
+  n_timesteps: !!float 4e6
+  n_delta: 64
+  n_top: 4
+  policy_kwargs: "dict(net_arch=[])"
+  # delta_std: 0.08
 
 # Tuned
 LunarLanderContinuous-v2:
@@ -215,4 +217,3 @@ A1Jumping-v0:
   # alive_bonus_offset: -1
   normalize: "dict(norm_obs=True, norm_reward=False)"
   # policy_kwargs: "dict(net_arch=[16])"
-
diff --git a/hyperparams/cem.yml b/hyperparams/cem.yml
new file mode 100644
index 000000000..fe834a282
--- /dev/null
+++ b/hyperparams/cem.yml
@@ -0,0 +1,190 @@
+# Tuned
+CartPole-v1:
+  n_envs: 1
+  n_timesteps: !!float 1e5
+  policy: 'LinearPolicy'
+  pop_size: 4
+  n_top: 2
+
+# Tuned
+Pendulum-v1: &pendulum-params
+  n_envs: 1
+  n_timesteps: !!float 2e6
+  policy: 'MlpPolicy'
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  pop_size: 8
+  n_top: 4
+  policy_kwargs: "dict(net_arch=[16])"
+
+# Tuned
+LunarLander-v2:
+  <<: *pendulum-params
+  n_timesteps: !!float 1e6
+  pop_size: 32
+  n_top: 6
+  policy_kwargs: "dict(net_arch=[])"
+
+# Tuned
+LunarLanderContinuous-v2:
+  <<: *pendulum-params
+  n_timesteps: !!float 2e6
+  pop_size: 16
+  n_top: 4
+
+# Tuned
+Acrobot-v1:
+  <<: *pendulum-params
+  n_timesteps: !!float 5e5
+
+# Tuned
+MountainCar-v0:
+  <<: *pendulum-params
+  pop_size: 16
+  n_timesteps: !!float 5e5
+
+# Tuned
+MountainCarContinuous-v0:
+  <<: *pendulum-params
+  n_timesteps: !!float 5e5
+
+# === Pybullet Envs ===
+# To be tuned
+HalfCheetahBulletEnv-v0: &pybullet-defaults
+  <<: *pendulum-params
+  n_timesteps: !!float 1e6
+  pop_size: 64
+  n_top: 6
+  extra_noise_std: 0.1
+
+# To be tuned
+AntBulletEnv-v0:
+  n_envs: 1
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.02
+  delta_std: !!float 0.03
+  n_delta: 32
+  n_top: 32
+  alive_bonus_offset: 0
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[128, 64])"
+  zero_policy: False
+
+
+Walker2DBulletEnv-v0:
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.03
+  delta_std: !!float 0.025
+  n_delta: 40
+  n_top: 30
+  alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[64, 64])"
+  zero_policy: False
+
+# Tuned
+HopperBulletEnv-v0:
+  <<: *pendulum-params
+  n_timesteps: !!float 1e6
+  pop_size: 64
+  n_top: 6
+  extra_noise_std: 0.1
+  alive_bonus_offset: -1
+
+ReacherBulletEnv-v0:
+  <<: *pybullet-defaults
+  n_timesteps: !!float 1e6
+
+# === Mujoco Envs ===
+# Tuned
+Swimmer-v3:
+  <<: *pendulum-params
+  n_timesteps: !!float 1e6
+  n_top: 2
+
+Hopper-v3:
+  <<: *pendulum-params
+  n_timesteps: !!float 1e6
+  pop_size: 64
+  n_top: 6
+  extra_noise_std: 0.1
+  alive_bonus_offset: -1
+
+
+HalfCheetah-v3:
+  <<: *pendulum-params
+  n_timesteps: !!float 1e6
+  pop_size: 50
+  n_top: 6
+  extra_noise_std: 0.1
+
+Walker2d-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 7.5e7
+  pop_size: 40
+  n_top: 30
+  alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+# To be tuned
+Ant-v3:
+  <<: *pendulum-params
+  # For hyperparameter optimization, so that alive_bonus_offset
+  # is taken into account (reward_offset):
+  # env_wrapper:
+  #   - utils.wrappers.DoneOnSuccessWrapper:
+  #       reward_offset: -1.0
+  #   - stable_baselines3.common.monitor.Monitor
+  n_timesteps: !!float 2e6
+  pop_size: 64
+  n_top: 6
+  # extra_noise_std: 0.1
+  # noise_multiplier: 0.999
+  alive_bonus_offset: -1
+  policy_kwargs: "dict(net_arch=[])"
+
+
+Humanoid-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 2.5e8
+  pop_size: 256
+  n_top: 256
+  alive_bonus_offset: -5
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+BipedalWalker-v3:
+  n_envs: 1
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 1e8
+  pop_size: 64
+  n_top: 32
+  alive_bonus_offset: -0.1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[16])"
+
+# To be tuned
+BipedalWalkerHardcore-v3:
+  n_envs: 1
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 5e8
+  pop_size: 64
+  n_top: 32
+  alive_bonus_offset: -0.1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[16])"
+
+A1Walking-v0:
+  <<: *pendulum-params
+  n_timesteps: !!float 2e6
+
+A1Jumping-v0:
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 7.5e7
+  pop_size: 80
+  n_top: 30
+  # alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  # policy_kwargs: "dict(net_arch=[16])"
diff --git a/utils/callbacks.py b/utils/callbacks.py
index f156f0e5d..7490705af 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -47,7 +47,6 @@ def _on_step(self) -> bool:
         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
             super(TrialEvalCallback, self)._on_step()
             self.eval_idx += 1
-            # report best or report current ?
             # report num_timesteps or elasped time ?
             self.trial.report(self.last_mean_reward, self.eval_idx)
             # Prune trial if need
diff --git a/utils/exp_manager.py b/utils/exp_manager.py
index 1bf718aff..10e5f62e1 100644
--- a/utils/exp_manager.py
+++ b/utils/exp_manager.py
@@ -167,7 +167,7 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
         self.create_callbacks()
 
         # Create env to have access to action space for action noise
-        n_envs = 1 if self.algo == "ars" else self.n_envs
+        n_envs = 1 if self.algo in ["ars", "cem"] else self.n_envs
         env = self.create_envs(n_envs, no_log=False)
 
         self._hyperparams = self._preprocess_action_noise(hyperparams, saved_hyperparams, env)
@@ -200,8 +200,8 @@ def learn(self, model: BaseAlgorithm) -> None:
         if len(self.callbacks) > 0:
             kwargs["callback"] = self.callbacks
 
-        # Special case for ARS
-        if self.algo == "ars" and self.n_envs > 1:
+        # Special case for ARS and CEM
+        if self.algo in ["ars", "cem"] and self.n_envs > 1:
             kwargs["async_eval"] = AsyncEval(
                 [lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
             )
@@ -625,7 +625,7 @@ def objective(self, trial: optuna.Trial) -> float:
         sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial)
         kwargs.update(sampled_hyperparams)
 
-        n_envs = 1 if self.algo == "ars" else self.n_envs
+        n_envs = 1 if self.algo in ["ars", "cem"] else self.n_envs
         env = self.create_envs(n_envs, no_log=True)
 
         # By default, do not activate verbose output to keep
@@ -667,8 +667,8 @@ def objective(self, trial: optuna.Trial) -> float:
             callbacks.append(eval_callback)
 
         learn_kwargs = {}
-        # Special case for ARS
-        if self.algo == "ars" and self.n_envs > 1:
+        # Special case for ARS and CEM
+        if self.algo in ["ars", "cem"] and self.n_envs > 1:
             learn_kwargs["async_eval"] = AsyncEval(
                 [lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
             )
diff --git a/utils/hyperparams_opt.py b/utils/hyperparams_opt.py
index 81add5858..459debc20 100644
--- a/utils/hyperparams_opt.py
+++ b/utils/hyperparams_opt.py
@@ -487,7 +487,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
     :return:
     """
     # n_eval_episodes = trial.suggest_categorical("n_eval_episodes", [1, 2])
-    n_delta = trial.suggest_categorical("n_delta", [4, 8, 6, 32, 64])
+    n_delta = trial.suggest_categorical("n_delta", [4, 8, 16, 32, 64])
     # learning_rate = trial.suggest_categorical("learning_rate", [0.01, 0.02, 0.025, 0.03])
     learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)
     delta_std = trial.suggest_categorical("delta_std", [0.01, 0.02, 0.025, 0.03, 0.05, 0.1, 0.2, 0.3])
@@ -519,9 +519,48 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
     }
 
 
+def sample_cem_params(trial: optuna.Trial) -> Dict[str, Any]:
+    """
+    Sampler for CEM hyperparams.
+    :param trial:
+    :return:
+    """
+    # n_eval_episodes = trial.suggest_categorical("n_eval_episodes", [1, 2])
+    # pop_size = trial.suggest_categorical("pop_size", [4, 8, 16, 32, 64, 128])
+    pop_size = trial.suggest_int("pop_size", 2, 130, step=4)
+    extra_noise_std = trial.suggest_categorical("extra_noise_std", [0.01, 0.02, 0.025, 0.03, 0.05, 0.1, 0.2, 0.3])
+    noise_multiplier = trial.suggest_categorical("noise_multiplier", [0.99, 0.995, 0.998, 0.999, 0.9995, 0.9998])
+    top_frac_size = trial.suggest_categorical("top_frac_size", [0.1, 0.2, 0.3, 0.4, 0.5, 0.8, 0.9, 1.0])
+    n_top = max(int(top_frac_size * pop_size), 2)
+    # use_diagonal_covariance = trial.suggest_categorical("use_diagonal_covariance", [False, True])
+
+    # net_arch = trial.suggest_categorical("net_arch", ["linear", "tiny", "small"])
+
+    # Note: remove bias to be as the original linear policy
+    # and do not squash output
+    # Comment out when doing hyperparams search with linear policy only
+    # net_arch = {
+    #     "linear": [],
+    #     "tiny": [16],
+    #     "small": [32],
+    # }[net_arch]
+
+    # TODO: optimize the alive_bonus_offset too
+
+    return {
+        # "n_eval_episodes": n_eval_episodes,
+        "pop_size": pop_size,
+        "extra_noise_std": extra_noise_std,
+        "noise_multiplier": noise_multiplier,
+        "n_top": n_top,
+        # "policy_kwargs": dict(net_arch=net_arch),
+    }
+
+
 HYPERPARAMS_SAMPLER = {
     "a2c": sample_a2c_params,
     "ars": sample_ars_params,
+    "cem": sample_cem_params,
     "ddpg": sample_ddpg_params,
     "dqn": sample_dqn_params,
     "qrdqn": sample_qrdqn_params,
diff --git a/utils/utils.py b/utils/utils.py
index 6072cc7cd..732e5ffb4 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -8,7 +8,7 @@
 import stable_baselines3 as sb3  # noqa: F401
 import torch as th  # noqa: F401
 import yaml
-from sb3_contrib import ARS, QRDQN, TQC, TRPO
+from sb3_contrib import ARS, CEM, QRDQN, TQC, TRPO
 from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.env_util import make_vec_env
@@ -27,6 +27,7 @@
     "td3": TD3,
     # SB3 Contrib,
     "ars": ARS,
+    "cem": CEM,
     "qrdqn": QRDQN,
     "tqc": TQC,
     "trpo": TRPO,
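Not part of the patch: a minimal sketch of how the pieces above fit together once "cem" is registered in ALGOS and hyperparams/cem.yml exists. It assumes an sb3-contrib build that actually ships CEM, and that CEM follows the same constructor and learn() conventions as ARS, as the shared ARS/CEM code paths in utils/exp_manager.py suggest; the YAML keys (pop_size, n_top, ...) are assumed to be forwarded as constructor kwargs, which is how the zoo treats hyperparameter files.

```python
# Hedged sketch, not a verified API: assumes CEM mirrors ARS's interface
# (policy name, kwargs from the YAML entry, standard learn() call).
import gym
import yaml

from sb3_contrib import CEM  # only available in builds that include CEM

# Read the tuned CartPole-v1 entry from the new hyperparams file
with open("hyperparams/cem.yml") as f:
    hyperparams = yaml.safe_load(f)["CartPole-v1"]

env = gym.make("CartPole-v1")

# The zoo passes the remaining YAML keys straight to the algorithm constructor
model = CEM(
    hyperparams["policy"],  # 'LinearPolicy'
    env,
    pop_size=hyperparams["pop_size"],  # 4
    n_top=hyperparams["n_top"],  # 2
    verbose=1,
)
model.learn(total_timesteps=int(hyperparams["n_timesteps"]))
```

With n_envs > 1, ExperimentManager.learn() additionally passes async_eval=AsyncEval(...) to learn(), exactly as it already does for ARS, so candidate evaluation is spread across worker envs.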