From 61a24de7d3d2b8f7f6d5100e94384683cb51c4f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:02:17 +0100 Subject: [PATCH 01/15] DuelingDQN --- sb3_contrib/__init__.py | 2 + sb3_contrib/dueling_dqn/__init__.py | 4 + sb3_contrib/dueling_dqn/dueling_dqn.py | 125 +++++++++++++++++ sb3_contrib/dueling_dqn/policies.py | 182 +++++++++++++++++++++++++ 4 files changed, 313 insertions(+) create mode 100644 sb3_contrib/dueling_dqn/__init__.py create mode 100644 sb3_contrib/dueling_dqn/dueling_dqn.py create mode 100644 sb3_contrib/dueling_dqn/policies.py diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 3fbd28d8..c4e02f64 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,6 +1,7 @@ import os from sb3_contrib.ars import ARS +from sb3_contrib.dueling_dqn import DuelingDQN from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.ppo_recurrent import RecurrentPPO from sb3_contrib.qrdqn import QRDQN @@ -14,6 +15,7 @@ __all__ = [ "ARS", + "DuelingDQN" "MaskablePPO", "RecurrentPPO", "QRDQN", diff --git a/sb3_contrib/dueling_dqn/__init__.py b/sb3_contrib/dueling_dqn/__init__.py new file mode 100644 index 00000000..4243fae0 --- /dev/null +++ b/sb3_contrib/dueling_dqn/__init__.py @@ -0,0 +1,4 @@ +from sb3_contrib.dueling_dqn.dueling_dqn import DuelingDQN +from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["DuelingDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/dueling_dqn/dueling_dqn.py b/sb3_contrib/dueling_dqn/dueling_dqn.py new file mode 100644 index 00000000..1e4c9f8b --- /dev/null +++ b/sb3_contrib/dueling_dqn/dueling_dqn.py @@ -0,0 +1,125 @@ +from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union + +import torch as th +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.dqn.dqn import DQN + +from sb3_contrib.dueling_dqn.policies import DuelingDQNPolicy + +SelfDuelingDQN = TypeVar("SelfDuelingDQN", bound="DuelingDQN") + + +class DuelingDQN(DQN): + """ + Dueling Deep Q-Network (Dueling DQN) + + Paper: https://arxiv.org/abs/1511.06581 + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) + Set to ``-1`` means to do as many gradient steps as steps done in the environment + during the rollout. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. 
+ :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param target_update_interval: update the target network every ``target_update_interval`` + environment steps. + :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced + :param exploration_initial_eps: initial value of random action probability + :param exploration_final_eps: final value of random action probability + :param max_grad_norm: The maximum value for the gradient clipping + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for + debug messages + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[DuelingDQNPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 0.0001, + buffer_size: int = 1000000, + learning_starts: int = 50000, + batch_size: int = 32, + tau: float = 1, + gamma: float = 0.99, + train_freq: Union[int, Tuple[int, str]] = 4, + gradient_steps: int = 1, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + target_update_interval: int = 10000, + exploration_fraction: float = 0.1, + exploration_initial_eps: float = 1, + exploration_final_eps: float = 0.05, + max_grad_norm: float = 10, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + replay_buffer_class, + replay_buffer_kwargs, + optimize_memory_usage, + target_update_interval, + exploration_fraction, + exploration_initial_eps, + exploration_final_eps, + max_grad_norm, + tensorboard_log, + policy_kwargs, + verbose, + seed, + device, + _init_setup_model, + ) + + def learn( + self: SelfDuelingDQN, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + tb_log_name: str = "DuelingDQN", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfDuelingDQN: + return super().learn( + total_timesteps, + callback, + log_interval, + tb_log_name, + reset_num_timesteps, + progress_bar, + ) diff --git a/sb3_contrib/dueling_dqn/policies.py b/sb3_contrib/dueling_dqn/policies.py new file mode 100644 index 00000000..6f87b56f --- /dev/null +++ b/sb3_contrib/dueling_dqn/policies.py @@ -0,0 +1,182 @@ +from typing import Any, Dict, List, Optional, Type + +import gym +import torch as th +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, CombinedExtractor, NatureCNN, create_mlp +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.dqn.policies import DQNPolicy, QNetwork +from torch import nn + + +class DuelingQNetwork(QNetwork): + """ + 
Dueling Q-Network. + + :param observation_space: Observation space + :param action_space: Action space + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + features_extractor: nn.Module, + features_dim: int, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + ): + super().__init__( + observation_space, + action_space, + features_extractor, + features_dim, + net_arch, + activation_fn, + normalize_images, + ) + + if net_arch is None: + net_arch = [64, 64] + + action_dim = self.action_space.n # number of actions + value_stream = create_mlp(self.features_dim, 1, self.net_arch, self.activation_fn) + self.value_stream = nn.Sequential(*value_stream) + advantage_stream = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) + self.advantage_stream = nn.Sequential(*advantage_stream) + + def forward(self, obs: th.Tensor) -> th.Tensor: + """ + Predict the q-values. + + :param obs: Observation + :return: The estimated Q-Value for each action. + """ + features = self.extract_features(obs) + values = self.value_stream(features) + advantages = self.advantage_stream(features) + qvals = values + (advantages - advantages.mean()) + return qvals + + +class DuelingDQNPolicy(DQNPolicy): + """ + Policy class for Dueling DQN. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def make_q_net(self) -> DuelingQNetwork: + # Make sure we always have separate networks for features extractors etc + net_args = self._update_features_extractor(self.net_args, features_extractor=None) + return DuelingQNetwork(**net_args).to(self.device) + + +MlpPolicy = DuelingDQNPolicy + + +class CnnPolicy(DuelingDQNPolicy): + """ + Policy class for Dueling DQN when using images as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MultiInputPolicy(DuelingDQNPolicy): + """ + Policy class for Dueling DQN when using dict observations as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) From 888a06b854e7e3d231da1ff7c26d6d537764c811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:36:12 +0100 Subject: [PATCH 02/15] dueling_dqn.rst --- docs/modules/dueling_dqn.rst | 154 +++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 docs/modules/dueling_dqn.rst diff --git a/docs/modules/dueling_dqn.rst b/docs/modules/dueling_dqn.rst new file mode 100644 index 00000000..68494de5 --- /dev/null +++ b/docs/modules/dueling_dqn.rst @@ -0,0 +1,154 @@ +.. _dueling_dqn: + +.. automodule:: sb3_contrib.dueling_dqn + + +Dueling-DQN +=========== + +`Dueling DQN `_ builds on `Deep Q-Network (DQN) `_ +and #TODO: + + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpPolicy + CnnPolicy + MultiInputPolicy + + +Notes +----- + +- Original paper: https://arxiv.org/abs/1511.06581 + + +Can I use? 
+---------- + +- Recurrent policies: ❌ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ✔️ +Box ❌ ✔️ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ +Dict ❌ ✔️ +============= ====== =========== + + +Example +------- + +.. code-block:: python + + import gym + + from sb3_contrib import DuelingDQN + + env = gym.make("CartPole-v1") + + policy_kwargs = dict(n_quantiles=50) + model = DuelingDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("dueling_dqn_cartpole") + + del model # remove to demonstrate saving and loading + + model = DuelingDQN.load("dueling_dqn_cartpole") + + obs = env.reset() + while True: + action, _states = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + env.render() + if done: + obs = env.reset() + + +Results +------- + +Result on Atari environments (10M steps, Pong and Breakout) and classic control tasks using 3 and 5 seeds. + +The complete learning curves are available in the `associated PR `_. #TODO: + + +.. note:: + + DuelingDQN implementation was validated against #TODO: valid the results + + +============ =========== =========== +Environments DuelingDQN DQN +============ =========== =========== +Breakout ~300 +Pong ~20 +CartPole 500 +/- 0 +MountainCar -107 +/- 4 +LunarLander 195 +/- 28 +Acrobot -74 +/- 2 +============ =========== =========== + +#TODO: Fill the tabular + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone RL-Zoo fork and checkout the branch ``feat/dueling-dqn``: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo/ + cd rl-baselines3-zoo/ + git checkout feat/dueling-dqn #TODO: create this branch + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo dueling_dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000 #TODO: check if that command line works + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a dueling_dqn -e Breakout Pong -f logs/ -o logs/dueling_dqn_results #TODO: check if that command line works + python scripts/plot_from_file.py -i logs/dueling_dqn_results.pkl -latex -l Dueling DQN #TODO: check if that command line works + + + +Parameters +---------- + +.. autoclass:: DuelingDQN + :members: + :inherited-members: + +.. _dueling_dqn_policies: + +Dueling DQN Policies +-------------------- + +.. autoclass:: MlpPolicy + :members: + :inherited-members: + +.. autoclass:: sb3_contrib.dueling_dqn.policies.DuelingDQNPolicy + :members: + :noindex: + +.. autoclass:: CnnPolicy + :members: + +.. 
autoclass:: MultiInputPolicy + :members: From 0d19cfeca7c0bb8b682c95bacc554b0f8f2cffb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:36:29 +0100 Subject: [PATCH 03/15] Update changelog --- docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2bb377b7..55792306 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -15,6 +15,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Introduced mypy type checking +- Added ``DuelingDQN`` Bug Fixes: ^^^^^^^^^^ From 0a303124d7db1a2eb95c10994acb5f452a51e563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:37:20 +0100 Subject: [PATCH 04/15] Add example in example.rst --- docs/guide/examples.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index cd4851d0..a47edd86 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -16,6 +16,20 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment. model.learn(total_timesteps=10_000, log_interval=4) model.save("tqc_pendulum") +DuelingDQN +---------- + +Train a Dueling DQN agent on the CartPole environment. + +.. code-block:: python + + from sb3_contrib import DuelingDQN + + policy_kwargs = dict(n_quantiles=50) + model = DuelingDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=10_000, log_interval=4) + model.save("dueling_dqn_cartpole") + QR-DQN ------ From 726d4b94f2f26534b0c122ae004afcea6af6cc7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:38:05 +0100 Subject: [PATCH 05/15] add dueling to index.rst --- docs/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/index.rst b/docs/index.rst index 5e322652..fe931569 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :caption: RL Algorithms modules/ars + modules/dueling_dqn modules/ppo_mask modules/ppo_recurrent modules/qrdqn From 400e63696ff48b204b1327e8ee34481d51dd9501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:41:52 +0100 Subject: [PATCH 06/15] Add policy_aliases --- sb3_contrib/dueling_dqn/dueling_dqn.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sb3_contrib/dueling_dqn/dueling_dqn.py b/sb3_contrib/dueling_dqn/dueling_dqn.py index 1e4c9f8b..5e26659b 100644 --- a/sb3_contrib/dueling_dqn/dueling_dqn.py +++ b/sb3_contrib/dueling_dqn/dueling_dqn.py @@ -2,10 +2,11 @@ import torch as th from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.dqn.dqn import DQN -from sb3_contrib.dueling_dqn.policies import DuelingDQNPolicy +from sb3_contrib.dueling_dqn.policies import CnnPolicy, DuelingDQNPolicy, MlpPolicy, MultiInputPolicy SelfDuelingDQN = TypeVar("SelfDuelingDQN", bound="DuelingDQN") @@ -52,6 +53,12 @@ class DuelingDQN(DQN): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: 
Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[DuelingDQNPolicy]], From b2ee62932ed52dd9724fc563928c9a2097b69419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:47:37 +0100 Subject: [PATCH 07/15] test-cnn --- tests/test_cnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e570aab6..550c9dee 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -8,10 +8,10 @@ from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped -from sb3_contrib import QRDQN, TQC, TRPO +from sb3_contrib import QRDQN, TQC, TRPO, DuelingDQN -@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, DuelingDQN]) def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip From 823760b5d0e82db56f4ef1dd121b08a5e8344bf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 11:51:55 +0100 Subject: [PATCH 08/15] typo --- sb3_contrib/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index c4e02f64..4a4e8ae4 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -15,8 +15,7 @@ __all__ = [ "ARS", - "DuelingDQN" - "MaskablePPO", + "DuelingDQN" "MaskablePPO", "RecurrentPPO", "QRDQN", "TQC", From bf44d999ec16063b3df53c32e57315af49ba84d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 12:12:27 +0100 Subject: [PATCH 09/15] simplification --- sb3_contrib/dueling_dqn/policies.py | 35 ++++++++++++++++------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/sb3_contrib/dueling_dqn/policies.py b/sb3_contrib/dueling_dqn/policies.py index 6f87b56f..61a77723 100644 --- a/sb3_contrib/dueling_dqn/policies.py +++ b/sb3_contrib/dueling_dqn/policies.py @@ -8,6 +8,25 @@ from torch import nn +class Dueling(nn.Module): + """ + Dueling submodule. + + :param value_stream: Value stream + :param advantage_stream: Advantage stream + """ + + def __init__(self, value_stream: nn.Module, advantage_stream: nn.Module) -> None: + super().__init__() + self.value_stream = value_stream + self.advantage_stream = advantage_stream + + def forward(self, features: th.Tensor) -> th.Tensor: + values = self.value_stream(features) + advantages = self.advantage_stream(features) + return values + (advantages - advantages.mean()) # TODO: check if dim is needed in mean() + + class DuelingQNetwork(QNetwork): """ Dueling Q-Network. @@ -45,22 +64,8 @@ def __init__( action_dim = self.action_space.n # number of actions value_stream = create_mlp(self.features_dim, 1, self.net_arch, self.activation_fn) - self.value_stream = nn.Sequential(*value_stream) advantage_stream = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) - self.advantage_stream = nn.Sequential(*advantage_stream) - - def forward(self, obs: th.Tensor) -> th.Tensor: - """ - Predict the q-values. - - :param obs: Observation - :return: The estimated Q-Value for each action. 
- """ - features = self.extract_features(obs) - values = self.value_stream(features) - advantages = self.advantage_stream(features) - qvals = values + (advantages - advantages.mean()) - return qvals + self.q_net = Dueling(nn.Sequential(*value_stream), nn.Sequential(*advantage_stream)) class DuelingDQNPolicy(DQNPolicy): From 372da3e8e6252acab27595f03846869be8eb301e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 12:12:43 +0100 Subject: [PATCH 10/15] typo --- sb3_contrib/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 4a4e8ae4..7374d687 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -15,7 +15,8 @@ __all__ = [ "ARS", - "DuelingDQN" "MaskablePPO", + "DuelingDQN", + "MaskablePPO", "RecurrentPPO", "QRDQN", "TQC", From d93f00b45e8ee128e4b357cabf6933718c6e8e85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 12:16:46 +0100 Subject: [PATCH 11/15] Rm policy_kwargs as error from copying from DRDQN --- docs/guide/examples.rst | 3 +-- docs/modules/dueling_dqn.rst | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a47edd86..6e6a9358 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -25,8 +25,7 @@ Train a Dueling DQN agent on the CartPole environment. from sb3_contrib import DuelingDQN - policy_kwargs = dict(n_quantiles=50) - model = DuelingDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) + model = DuelingDQN("MlpPolicy", "CartPole-v1", verbose=1) model.learn(total_timesteps=10_000, log_interval=4) model.save("dueling_dqn_cartpole") diff --git a/docs/modules/dueling_dqn.rst b/docs/modules/dueling_dqn.rst index 68494de5..2cb426b1 100644 --- a/docs/modules/dueling_dqn.rst +++ b/docs/modules/dueling_dqn.rst @@ -56,8 +56,7 @@ Example env = gym.make("CartPole-v1") - policy_kwargs = dict(n_quantiles=50) - model = DuelingDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1) + model = DuelingDQN("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000, log_interval=4) model.save("dueling_dqn_cartpole") @@ -71,7 +70,7 @@ Example obs, reward, done, info = env.step(action) env.render() if done: - obs = env.reset() + obs = env.reset() Results From bd8755a4730d38df17dd2d2d0ec7a05b333ec984 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 12:26:04 +0100 Subject: [PATCH 12/15] Update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f54ae854..5c3415ce 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ See documentation for the full list of included features. 
**RL Algorithms**: - [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055) +- [Dueling DQN](https://arxiv.org/abs/1511.06581) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) - [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/) From 9b2366318285c57dfd60329414a37ec892dd8c91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 12:26:15 +0100 Subject: [PATCH 13/15] Update setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 14214f7c..07ab337f 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ **RL Algorithms**: - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) +- [Dueling DQN](https://arxiv.org/abs/1511.06581) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) From 7a34a779b024f49260678b4f611d01b842b5c744 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 12:26:33 +0100 Subject: [PATCH 14/15] Add dueling to the list of algorithm --- docs/guide/algos.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 234e6f8c..911a6e7e 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions, Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ ARS ✔️ ❌️ ❌ ❌ ✔️ +Dueling DQN ❌ ️✔️ ❌ ❌ ✔️ MaskablePPO ❌ ✔️ ✔️ ✔️ ✔️ QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ RecurrentPPO ✔️ ✔️ ✔️ ✔️ ✔️ From 989f31fc5964c1ec0cd6783f1fafa575250bb5ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 10 Dec 2022 12:34:11 +0100 Subject: [PATCH 15/15] ignore mypy error --- sb3_contrib/dueling_dqn/policies.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sb3_contrib/dueling_dqn/policies.py b/sb3_contrib/dueling_dqn/policies.py index 61a77723..1796f4ee 100644 --- a/sb3_contrib/dueling_dqn/policies.py +++ b/sb3_contrib/dueling_dqn/policies.py @@ -65,7 +65,10 @@ def __init__( action_dim = self.action_space.n # number of actions value_stream = create_mlp(self.features_dim, 1, self.net_arch, self.activation_fn) advantage_stream = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) - self.q_net = Dueling(nn.Sequential(*value_stream), nn.Sequential(*advantage_stream)) + # self.q_net is a Sequential in DQN, and is a Dueling here, and thus raises a mypy error. + # Since it would take a lot of effort to make it mypy compliant, and this implementation + # is temporary (will be special case of Rainbow in the future) we ignore the error. + self.q_net = Dueling(nn.Sequential(*value_stream), nn.Sequential(*advantage_stream)) # type: ignore[assignment] class DuelingDQNPolicy(DQNPolicy):
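
For reference, below is a minimal standalone sketch of the dueling head that these patches implement (a scalar value stream combined with a mean-centred advantage stream). It is only an illustration: it assumes plain PyTorch and hypothetical layer sizes, whereas the actual implementation above subclasses SB3's ``QNetwork``, builds both streams with ``create_mlp``, and stores them in a ``Dueling`` module assigned to ``self.q_net``. The sketch centres the advantages over the action dimension (``dim=1``), which is the formulation from the paper (Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)); the patch currently takes a global ``mean()`` and leaves a TODO about the dimension.

.. code-block:: python

    import torch as th
    from torch import nn


    class DuelingHead(nn.Module):
        """Combine a value stream and an advantage stream into Q-values:
        Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

        def __init__(self, value_stream: nn.Module, advantage_stream: nn.Module) -> None:
            super().__init__()
            self.value_stream = value_stream
            self.advantage_stream = advantage_stream

        def forward(self, features: th.Tensor) -> th.Tensor:
            values = self.value_stream(features)          # shape (batch, 1)
            advantages = self.advantage_stream(features)  # shape (batch, n_actions)
            # Centre the advantages over the action dimension, as in the paper.
            return values + advantages - advantages.mean(dim=1, keepdim=True)


    if __name__ == "__main__":
        feature_dim, n_actions = 64, 4  # hypothetical sizes, for illustration only
        head = DuelingHead(
            nn.Sequential(nn.Linear(feature_dim, 64), nn.ReLU(), nn.Linear(64, 1)),
            nn.Sequential(nn.Linear(feature_dim, 64), nn.ReLU(), nn.Linear(64, n_actions)),
        )
        q_values = head(th.randn(8, feature_dim))
        print(q_values.shape)  # torch.Size([8, 4])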