Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install dependencies for docs and tests
pip install stable_baselines3[extra,tests,docs]
# Install master version
Expand Down
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
Changelog
==========

Release 1.4.1a1 (WIP)
Release 1.4.1a3 (WIP)
-------------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched minimum Gym version to 0.21.0.
- Upgraded to Stable-Baselines3 >= 1.4.1a1
- Changed default policy architecture for ARS/CEM to ``[32]`` instead of ``[64, 64]``

New Features:
^^^^^^^^^^^^^
- Allow PPO to turn of advantage normalization (see `PR #61 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/61>`_) @vwxyzjn
- Added noisy Cross Entropy Method (CEM)

Bug Fixes:
^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.cem import CEM
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/ars/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
)

if net_arch is None:
net_arch = [64, 64]
net_arch = [32]

self.net_arch = net_arch
self.features_extractor = self.make_features_extractor()
Expand Down
2 changes: 2 additions & 0 deletions sb3_contrib/cem/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from sb3_contrib.cem.cem import CEM
from sb3_contrib.cem.policies import LinearPolicy, MlpPolicy
366 changes: 366 additions & 0 deletions sb3_contrib/cem/cem.py

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions sb3_contrib/cem/policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from stable_baselines3.common.policies import register_policy

from sb3_contrib.ars.policies import ARSLinearPolicy, ARSPolicy

MlpPolicy = ARSPolicy
LinearPolicy = ARSLinearPolicy


register_policy("LinearPolicy", LinearPolicy)
register_policy("MlpPolicy", MlpPolicy)
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.4.1a1
1.4.1a3
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators
per-file-ignores =
./sb3_contrib/__init__.py:F401
./sb3_contrib/ars/__init__.py:F401
./sb3_contrib/cem/__init__.py:F401
./sb3_contrib/ppo_mask/__init__.py:F401
./sb3_contrib/qrdqn/__init__.py:F401
./sb3_contrib/tqc/__init__.py:F401
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.4.1a1",
"torch>=1.11",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down
20 changes: 16 additions & 4 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

from sb3_contrib import ARS, QRDQN, TQC, TRPO, MaskablePPO
from sb3_contrib import ARS, CEM, QRDQN, TQC, TRPO, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.vec_env import AsyncEval

Expand Down Expand Up @@ -92,13 +92,25 @@ def test_ars(policy_str, env_id):
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)


def test_ars_multi_env():
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("policy_str", ["LinearPolicy", "MlpPolicy"])
def test_cem(policy_str, env_id):
model = CEM(policy_str, env_id, pop_size=2, verbose=1, seed=0)
model.learn(total_timesteps=500, log_interval=1, eval_freq=250)


@pytest.mark.parametrize("model_class", [ARS, CEM])
def test_es_multi_env(model_class):
env = make_vec_env("Pendulum-v1", n_envs=2)
model = ARS("MlpPolicy", env, n_delta=1)
kwargs = dict(n_delta=1) if model_class == ARS else dict(pop_size=2)

model = model_class("MlpPolicy", env, **kwargs)
model.learn(total_timesteps=250)

kwargs = dict(n_delta=2) if model_class == ARS else dict(pop_size=3)

env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
model = ARS("MlpPolicy", env, n_delta=2, seed=0)
model = model_class("MlpPolicy", env, seed=0, **kwargs)
# with parallelism
async_eval = AsyncEval([lambda: VecNormalize(make_vec_env("Pendulum-v1", n_envs=1)) for _ in range(2)], model.policy)
async_eval.seed(0)
Expand Down