diff --git a/README.md b/README.md index f54ae854..5c3415ce 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ See documentation for the full list of included features. **RL Algorithms**: - [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055) +- [Dueling DQN](https://arxiv.org/abs/1511.06581) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) - [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 234e6f8c..911a6e7e 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions, Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ ARS ✔️ ❌️ ❌ ❌ ✔️ +Dueling DQN ❌ ️✔️ ❌ ❌ ✔️ MaskablePPO ❌ ✔️ ✔️ ✔️ ✔️ QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ RecurrentPPO ✔️ ✔️ ✔️ ✔️ ✔️ diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index cd4851d0..6e6a9358 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -16,6 +16,19 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment. model.learn(total_timesteps=10_000, log_interval=4) model.save("tqc_pendulum") +DuelingDQN +---------- + +Train a Dueling DQN agent on the CartPole environment. + +.. code-block:: python + + from sb3_contrib import DuelingDQN + + model = DuelingDQN("MlpPolicy", "CartPole-v1", verbose=1) + model.learn(total_timesteps=10_000, log_interval=4) + model.save("dueling_dqn_cartpole") + QR-DQN ------ diff --git a/docs/index.rst b/docs/index.rst index 5e322652..fe931569 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :caption: RL Algorithms modules/ars + modules/dueling_dqn modules/ppo_mask modules/ppo_recurrent modules/qrdqn diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 942c98ae..df53669b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ New Features: ^^^^^^^^^^^^^ - Introduced mypy type checking - Added ``with_bias`` parameter to ``ARSPolicy`` +- Added ``DuelingDQN`` Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/dueling_dqn.rst b/docs/modules/dueling_dqn.rst new file mode 100644 index 00000000..2cb426b1 --- /dev/null +++ b/docs/modules/dueling_dqn.rst @@ -0,0 +1,153 @@ +.. _dueling_dqn: + +.. automodule:: sb3_contrib.dueling_dqn + + +Dueling-DQN +=========== + +`Dueling DQN `_ builds on `Deep Q-Network (DQN) `_ +and #TODO: + + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpPolicy + CnnPolicy + MultiInputPolicy + + +Notes +----- + +- Original paper: https://arxiv.org/abs/1511.06581 + + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ✔️ +Box ❌ ✔️ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ +Dict ❌ ✔️ +============= ====== =========== + + +Example +------- + +.. 
code-block:: python + + import gym + + from sb3_contrib import DuelingDQN + + env = gym.make("CartPole-v1") + + model = DuelingDQN("MlpPolicy", env, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("dueling_dqn_cartpole") + + del model # remove to demonstrate saving and loading + + model = DuelingDQN.load("dueling_dqn_cartpole") + + obs = env.reset() + while True: + action, _states = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + env.render() + if done: + obs = env.reset() + + +Results +------- + +Result on Atari environments (10M steps, Pong and Breakout) and classic control tasks using 3 and 5 seeds. + +The complete learning curves are available in the `associated PR `_. #TODO: + + +.. note:: + + DuelingDQN implementation was validated against #TODO: valid the results + + +============ =========== =========== +Environments DuelingDQN DQN +============ =========== =========== +Breakout ~300 +Pong ~20 +CartPole 500 +/- 0 +MountainCar -107 +/- 4 +LunarLander 195 +/- 28 +Acrobot -74 +/- 2 +============ =========== =========== + +#TODO: Fill the tabular + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone RL-Zoo fork and checkout the branch ``feat/dueling-dqn``: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo/ + cd rl-baselines3-zoo/ + git checkout feat/dueling-dqn #TODO: create this branch + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo dueling_dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000 #TODO: check if that command line works + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a dueling_dqn -e Breakout Pong -f logs/ -o logs/dueling_dqn_results #TODO: check if that command line works + python scripts/plot_from_file.py -i logs/dueling_dqn_results.pkl -latex -l Dueling DQN #TODO: check if that command line works + + + +Parameters +---------- + +.. autoclass:: DuelingDQN + :members: + :inherited-members: + +.. _dueling_dqn_policies: + +Dueling DQN Policies +-------------------- + +.. autoclass:: MlpPolicy + :members: + :inherited-members: + +.. autoclass:: sb3_contrib.dueling_dqn.policies.DuelingDQNPolicy + :members: + :noindex: + +.. autoclass:: CnnPolicy + :members: + +.. 
autoclass:: MultiInputPolicy + :members: diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 3fbd28d8..7374d687 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,6 +1,7 @@ import os from sb3_contrib.ars import ARS +from sb3_contrib.dueling_dqn import DuelingDQN from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.ppo_recurrent import RecurrentPPO from sb3_contrib.qrdqn import QRDQN @@ -14,6 +15,7 @@ __all__ = [ "ARS", + "DuelingDQN", "MaskablePPO", "RecurrentPPO", "QRDQN", diff --git a/sb3_contrib/dueling_dqn/__init__.py b/sb3_contrib/dueling_dqn/__init__.py new file mode 100644 index 00000000..4243fae0 --- /dev/null +++ b/sb3_contrib/dueling_dqn/__init__.py @@ -0,0 +1,4 @@ +from sb3_contrib.dueling_dqn.dueling_dqn import DuelingDQN +from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["DuelingDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/dueling_dqn/dueling_dqn.py b/sb3_contrib/dueling_dqn/dueling_dqn.py new file mode 100644 index 00000000..5e26659b --- /dev/null +++ b/sb3_contrib/dueling_dqn/dueling_dqn.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union + +import torch as th +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.dqn.dqn import DQN + +from sb3_contrib.dueling_dqn.policies import CnnPolicy, DuelingDQNPolicy, MlpPolicy, MultiInputPolicy + +SelfDuelingDQN = TypeVar("SelfDuelingDQN", bound="DuelingDQN") + + +class DuelingDQN(DQN): + """ + Dueling Deep Q-Network (Dueling DQN) + + Paper: https://arxiv.org/abs/1511.06581 + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) + Set to ``-1`` means to do as many gradient steps as steps done in the environment + during the rollout. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param target_update_interval: update the target network every ``target_update_interval`` + environment steps. 
+ :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced + :param exploration_initial_eps: initial value of random action probability + :param exploration_final_eps: final value of random action probability + :param max_grad_norm: The maximum value for the gradient clipping + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for + debug messages + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + + def __init__( + self, + policy: Union[str, Type[DuelingDQNPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 0.0001, + buffer_size: int = 1000000, + learning_starts: int = 50000, + batch_size: int = 32, + tau: float = 1, + gamma: float = 0.99, + train_freq: Union[int, Tuple[int, str]] = 4, + gradient_steps: int = 1, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + target_update_interval: int = 10000, + exploration_fraction: float = 0.1, + exploration_initial_eps: float = 1, + exploration_final_eps: float = 0.05, + max_grad_norm: float = 10, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + replay_buffer_class, + replay_buffer_kwargs, + optimize_memory_usage, + target_update_interval, + exploration_fraction, + exploration_initial_eps, + exploration_final_eps, + max_grad_norm, + tensorboard_log, + policy_kwargs, + verbose, + seed, + device, + _init_setup_model, + ) + + def learn( + self: SelfDuelingDQN, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + tb_log_name: str = "DuelingDQN", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfDuelingDQN: + return super().learn( + total_timesteps, + callback, + log_interval, + tb_log_name, + reset_num_timesteps, + progress_bar, + ) diff --git a/sb3_contrib/dueling_dqn/policies.py b/sb3_contrib/dueling_dqn/policies.py new file mode 100644 index 00000000..1796f4ee --- /dev/null +++ b/sb3_contrib/dueling_dqn/policies.py @@ -0,0 +1,190 @@ +from typing import Any, Dict, List, Optional, Type + +import gym +import torch as th +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, CombinedExtractor, NatureCNN, create_mlp +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.dqn.policies import DQNPolicy, QNetwork +from torch import nn + + +class Dueling(nn.Module): + """ + Dueling submodule. 
+ + :param value_stream: Value stream + :param advantage_stream: Advantage stream + """ + + def __init__(self, value_stream: nn.Module, advantage_stream: nn.Module) -> None: + super().__init__() + self.value_stream = value_stream + self.advantage_stream = advantage_stream + + def forward(self, features: th.Tensor) -> th.Tensor: + values = self.value_stream(features) + advantages = self.advantage_stream(features) + return values + (advantages - advantages.mean()) # TODO: check if dim is needed in mean() + + +class DuelingQNetwork(QNetwork): + """ + Dueling Q-Network. + + :param observation_space: Observation space + :param action_space: Action space + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + features_extractor: nn.Module, + features_dim: int, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + ): + super().__init__( + observation_space, + action_space, + features_extractor, + features_dim, + net_arch, + activation_fn, + normalize_images, + ) + + if net_arch is None: + net_arch = [64, 64] + + action_dim = self.action_space.n # number of actions + value_stream = create_mlp(self.features_dim, 1, self.net_arch, self.activation_fn) + advantage_stream = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) + # self.q_net is a Sequential in DQN, and is a Dueling here, and thus raises a mypy error. + # Since it would take a lot of effort to make it mypy compliant, and this implementation + # is temporary (will be special case of Rainbow in the future) we ignore the error. + self.q_net = Dueling(nn.Sequential(*value_stream), nn.Sequential(*advantage_stream)) # type: ignore[assignment] + + +class DuelingDQNPolicy(DQNPolicy): + """ + Policy class for Dueling DQN. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def make_q_net(self) -> DuelingQNetwork: + # Make sure we always have separate networks for features extractors etc + net_args = self._update_features_extractor(self.net_args, features_extractor=None) + return DuelingQNetwork(**net_args).to(self.device) + + +MlpPolicy = DuelingDQNPolicy + + +class CnnPolicy(DuelingDQNPolicy): + """ + Policy class for Dueling DQN when using images as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MultiInputPolicy(DuelingDQNPolicy): + """ + Policy class for Dueling DQN when using dict observations as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) diff --git a/setup.py b/setup.py index 60fe1d33..89c7a8a1 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ **RL Algorithms**: - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) +- [Dueling DQN](https://arxiv.org/abs/1511.06581) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e570aab6..550c9dee 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -8,10 +8,10 @@ from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped -from sb3_contrib import QRDQN, TQC, TRPO +from sb3_contrib import QRDQN, TQC, TRPO, DuelingDQN -@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, DuelingDQN]) def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake 
grayscale with frameskip
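
For reference, a minimal standalone sketch of the dueling aggregation that ``Dueling.forward`` in ``sb3_contrib/dueling_dqn/policies.py`` implements. The 64-dimensional features, 4 actions and batch size below are placeholder assumptions, not values from the diff. Taking the mean over the action dimension per sample (``dim=1, keepdim=True``) is the aggregation described in the Dueling DQN paper (https://arxiv.org/abs/1511.06581) and is the question raised by the ``# TODO: check if dim is needed in mean()`` comment above.

.. code-block:: python

    import torch as th
    from torch import nn

    # Placeholder value/advantage streams (illustrative sizes: 64-dim features, 4 actions)
    value_stream = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))
    advantage_stream = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 4))

    features = th.randn(32, 64)              # batch of 32 feature vectors
    values = value_stream(features)          # V(s), shape (32, 1)
    advantages = advantage_stream(features)  # A(s, a), shape (32, 4)

    # Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a'))
    # Subtracting the per-state mean advantage keeps V and A identifiable;
    # the mean is taken over the action dimension independently for each sample.
    q_values = values + advantages - advantages.mean(dim=1, keepdim=True)  # shape (32, 4)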