1 change: 1 addition & 0 deletions README.md
@@ -26,6 +26,7 @@ See documentation for the full list of included features.

**RL Algorithms**:
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
- [Noisy Cross Entropy Method (CEM)](http://dx.doi.org/10.1162/neco.2006.18.12.2936)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
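A minimal usage sketch for the newly listed CEM may help here; it assumes the standard Stable-Baselines3 `.learn()` API, and the environment id, timestep budget, and hyperparameters are illustrative rather than taken from this PR:

```python
# Minimal sketch: train the newly added noisy Cross Entropy Method (CEM).
# Environment id and timestep budget are illustrative.
from sb3_contrib import CEM

model = CEM("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=100_000)
model.save("cem_pendulum")
```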
5 changes: 5 additions & 0 deletions docs/misc/changelog.rst
@@ -9,10 +9,12 @@ Release 2.7.0 (2025-07-25)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 2.7.0
- Changed default policy architecture for ARS/CEM to ``[32]`` instead of ``[64, 64]``

New Features:
^^^^^^^^^^^^^
- Added support for n-step returns for off-policy algorithms via the ``n_steps`` parameter
- Added noisy Cross Entropy Method (CEM)

Bug Fixes:
^^^^^^^^^^
@@ -357,6 +359,9 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Release 1.5.0 (2022-03-25)
-------------------------------

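Given the breaking change above (the default ARS/CEM policy architecture moves from `[64, 64]` to `[32]`), here is a short sketch of how the previous network size could be restored; the `net_arch` key in `policy_kwargs` follows the usual SB3 convention and is an assumption, not something verified by this diff:

```python
# Sketch: opt back into the old [64, 64] architecture after the default changed to [32].
# Assumes policy_kwargs/net_arch behave as elsewhere in Stable-Baselines3.
from sb3_contrib import ARS

model = ARS("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(net_arch=[64, 64]))
```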
1 change: 0 additions & 1 deletion pyproject.toml
@@ -27,7 +27,6 @@ follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
| sb3_contrib/ars/ars.py$
| sb3_contrib/common/recurrent/policies.py$
| sb3_contrib/common/recurrent/buffers.py$
| tests/test_train_eval_mode.py$
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.cem import CEM
from sb3_contrib.crossq import CrossQ
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
@@ -15,6 +16,7 @@

__all__ = [
"ARS",
"CEM",
"QRDQN",
"TQC",
"TRPO",
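The `__init__.py` hunk adds `CEM` next to the existing imports and in `__all__`; a quick import check using only names visible in the diff:

```python
# Smoke test: the new CEM export is available alongside the existing algorithms.
from sb3_contrib import ARS, CEM, QRDQN, TQC, TRPO

print([cls.__name__ for cls in (ARS, CEM, QRDQN, TQC, TRPO)])
```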
222 changes: 22 additions & 200 deletions sb3_contrib/ars/ars.py
@@ -1,29 +1,21 @@
import copy
import sys
import time
import warnings
from functools import partial
from typing import Any, ClassVar, Optional, TypeVar, Union

import numpy as np
import torch as th
import torch.nn.utils
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import FloatSchedule, safe_mean
from stable_baselines3.common.utils import FloatSchedule

from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
from sb3_contrib.common.policies import ESLinearPolicy, ESPolicy
from sb3_contrib.common.population_based_algorithm import PopulationBasedAlgorithm
from sb3_contrib.common.vec_env.async_eval import AsyncEval

SelfARS = TypeVar("SelfARS", bound="ARS")


class ARS(BaseAlgorithm):
class ARS(PopulationBasedAlgorithm):
"""
Augmented Random Search: https://arxiv.org/abs/1803.07055

@@ -51,13 +43,14 @@ class ARS(BaseAlgorithm):
"""

policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": MlpPolicy,
"LinearPolicy": LinearPolicy,
"MlpPolicy": ESPolicy,
"LinearPolicy": ESLinearPolicy,
}
weights: th.Tensor # Need to call init model to initialize weights

def __init__(
self,
policy: Union[str, type[ARSPolicy]],
policy: Union[str, type[ESPolicy]],
env: Union[GymEnv, str],
n_delta: int = 8,
n_top: Optional[int] = None,
@@ -78,20 +71,20 @@ def __init__(
policy,
env,
learning_rate=learning_rate,
pop_size=2 * n_delta,
alive_bonus_offset=alive_bonus_offset,
n_eval_episodes=n_eval_episodes,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
supported_action_spaces=(spaces.Box, spaces.Discrete),
support_multi_env=True,
seed=seed,
)

self.n_delta = n_delta
self.pop_size = 2 * n_delta
self.delta_std_schedule = FloatSchedule(delta_std)
self.n_eval_episodes = n_eval_episodes

if n_top is None:
n_top = n_delta
@@ -103,13 +96,7 @@

self.n_top = n_top

self.alive_bonus_offset = alive_bonus_offset
self.zero_policy = zero_policy
self.weights = None # Need to call init model to initialize weight
self.processes = None
# Keep track of how many steps where elapsed before a new rollout
# Important for syncing observation normalization between workers
self.old_count = 0

if _init_setup_model:
self._setup_model()
@@ -125,137 +112,7 @@ def _setup_model(self) -> None:

if self.zero_policy:
self.weights = th.zeros_like(self.weights, requires_grad=False)
self.policy.load_from_vector(self.weights.cpu())

def _mimic_monitor_wrapper(self, episode_rewards: np.ndarray, episode_lengths: np.ndarray) -> None:
"""
Helper to mimic Monitor wrapper and report episode statistics (mean reward, mean episode length).

:param episode_rewards: List containing per-episode rewards
:param episode_lengths: List containing per-episode lengths (in number of steps)
"""
# Mimic Monitor Wrapper
infos = [
{"episode": {"r": episode_reward, "l": episode_length}}
for episode_reward, episode_length in zip(episode_rewards, episode_lengths)
]

self._update_info_buffer(infos)

def _trigger_callback(
self,
_locals: dict[str, Any],
_globals: dict[str, Any],
callback: BaseCallback,
n_envs: int,
) -> None:
"""
Callback passed to the ``evaluate_policy()`` helper
in order to increment the number of timesteps
and trigger events in the single process version.

:param _locals:
:param _globals:
:param callback: Callback that will be called at every step
:param n_envs: Number of environments
"""
self.num_timesteps += n_envs
callback.on_step()

def evaluate_candidates(
self, candidate_weights: th.Tensor, callback: BaseCallback, async_eval: Optional[AsyncEval]
) -> th.Tensor:
"""
Evaluate each candidate.

:param candidate_weights: The candidate weights to be evaluated.
:param callback: Callback that will be called at each step
(or after evaluation in the multiprocess version)
:param async_eval: The object for asynchronous evaluation of candidates.
:return: The episodic return for each candidate.
"""

batch_steps = 0
# returns == sum of rewards
candidate_returns = th.zeros(self.pop_size, device=self.device)
train_policy = copy.deepcopy(self.policy)
# Empty buffer to show only mean over one iteration (one set of candidates) in the logs
self.ep_info_buffer = []
callback.on_rollout_start()

if async_eval is not None:
# Multiprocess asynchronous version
async_eval.send_jobs(candidate_weights, self.pop_size)
results = async_eval.get_results()

for weights_idx, (episode_rewards, episode_lengths) in results:
# Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += np.sum(episode_lengths)
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)

# Combine the filter stats of each process for normalization
for worker_obs_rms in async_eval.get_obs_rms():
if self._vec_normalize_env is not None:
# worker_obs_rms.count -= self.old_count
self._vec_normalize_env.obs_rms.combine(worker_obs_rms)
# Hack: don't count timesteps twice (between the two are synced)
# otherwise it will lead to overflow,
# in practice we would need two RunningMeanStats
self._vec_normalize_env.obs_rms.count -= self.old_count

# Synchronise VecNormalize if needed
if self._vec_normalize_env is not None:
async_eval.sync_obs_rms(self._vec_normalize_env.obs_rms.copy())
self.old_count = self._vec_normalize_env.obs_rms.count

# Hack to have Callback events
for _ in range(batch_steps // len(async_eval.remotes)):
self.num_timesteps += len(async_eval.remotes)
callback.on_step()
else:
# Single process, synchronous version
for weights_idx in range(self.pop_size):
# Load current candidate weights
train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
# Evaluate the candidate
episode_rewards, episode_lengths = evaluate_policy(
train_policy,
self.env,
n_eval_episodes=self.n_eval_episodes,
return_episode_rewards=True,
# Increment num_timesteps too (slight mismatch with multi envs)
callback=partial(self._trigger_callback, callback=callback, n_envs=self.env.num_envs),
warn=False,
)
# Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += sum(episode_lengths)
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)

# Note: we increment the num_timesteps inside the evaluate_policy()
# however when using multiple environments, there will be a slight
# mismatch between the number of timesteps used and the number
# of calls to the step() method (cf. implementation of evaluate_policy())
# self.num_timesteps += batch_steps

callback.on_rollout_end()

return candidate_returns

def dump_logs(self) -> None:
"""
Dump information to the logger.
"""
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
self.policy.load_from_vector(self.weights.cpu().numpy())

def _do_one_update(self, callback: BaseCallback, async_eval: Optional[AsyncEval]) -> None:
"""
@@ -295,7 +152,7 @@ def _do_one_update(self, callback: BaseCallback, async_eval: Optional[AsyncEval]
step_size = learning_rate / (self.n_top * return_std + 1e-6)
# Approximate gradient step
self.weights = self.weights + step_size * ((plus_returns - minus_returns) @ deltas)
self.policy.load_from_vector(self.weights.cpu())
self.policy.load_from_vector(self.weights.cpu().numpy())

self.logger.record("train/iterations", self._n_updates, exclude="tensorboard")
self.logger.record("train/delta_std", delta_std)
@@ -305,7 +162,7 @@

self._n_updates += 1

def learn(
def learn( # type: ignore[override]
self: SelfARS,
total_timesteps: int,
callback: MaybeCallback = None,
@@ -328,47 +185,12 @@ def learn(
:return: the trained model
"""

total_steps, callback = self._setup_learn(
total_timesteps,
callback,
reset_num_timesteps,
tb_log_name,
progress_bar,
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
tb_log_name=tb_log_name,
reset_num_timesteps=reset_num_timesteps,
async_eval=async_eval,
progress_bar=progress_bar,
)

callback.on_training_start(locals(), globals())

while self.num_timesteps < total_steps:
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
self._do_one_update(callback, async_eval)
if log_interval is not None and self._n_updates % log_interval == 0:
self.dump_logs()

if async_eval is not None:
async_eval.close()

callback.on_training_end()

return self

def set_parameters(
self,
load_path_or_dict: Union[str, dict[str, dict]],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
) -> None:
# Patched set_parameters() to handle ARS linear policy saved with sb3-contrib < 1.7.0
params = None
if isinstance(load_path_or_dict, dict):
params = load_path_or_dict
else:
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)

# Patch to load LinearPolicy saved using sb3-contrib < 1.7.0
# See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/122#issuecomment-1331981230
for name in {"weight", "bias"}:
if f"action_net.{name}" in params.get("policy", {}):
params["policy"][f"action_net.0.{name}"] = params["policy"][f"action_net.{name}"]
del params["policy"][f"action_net.{name}"]

super().set_parameters(params, exact_match=exact_match)
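The gradient-step hunk above keeps the ARS update rule and only changes how the weights are handed back to the policy (`.cpu().numpy()` instead of `.cpu()`). A standalone sketch of that update, with illustrative shapes and random returns rather than anything computed by the library:

```python
# Illustrative re-implementation of the ARS weight update shown in _do_one_update():
#   step_size = lr / (n_top * return_std + 1e-6)
#   weights  += step_size * (plus_returns - minus_returns) @ deltas
import torch as th

n_top, n_params, learning_rate = 4, 16, 0.02
deltas = th.randn(n_top, n_params)      # top-k perturbation directions
plus_returns = th.randn(n_top)          # returns for theta + delta_std * delta
minus_returns = th.randn(n_top)         # returns for theta - delta_std * delta
weights = th.zeros(n_params)

return_std = th.cat([plus_returns, minus_returns]).std()
step_size = learning_rate / (n_top * return_std + 1e-6)
# Approximate gradient ascent step on the episodic return
weights = weights + step_size * ((plus_returns - minus_returns) @ deltas)
print(weights)
```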