diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index fb6b9f281..3c0a72aec 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -62,6 +62,7 @@ New Features:
 - Added support for n-step returns for off-policy algorithms via the `n_steps` parameter
 - Added ``NStepReplayBuffer`` that allows to compute n-step returns without additional memory requirement (and without for loops)
 - Added Gymnasium v1.2 support
+- Added ``PrioritizedReplayBuffer`` and the Prioritized Approximation Loss (``pal_loss``) for DQN
 
 Bug Fixes:
 ^^^^^^^^^^
diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst
index 78f70f698..8e49c6044 100644
--- a/docs/modules/dqn.rst
+++ b/docs/modules/dqn.rst
@@ -28,7 +28,8 @@ Notes
 - Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial
 
 .. note::
-    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
+    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN or Dueling-DQN.
+    Prioritized Experience Replay can be used by passing ``PrioritizedReplayBuffer`` via the ``replay_buffer_class`` argument.
 
 Can I use?
 ----------
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 773dc7f26..e909fb003 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -952,3 +952,28 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
             rewards=self.to_torch(n_step_returns),
             discounts=self.to_torch(target_q_discounts),
         )
+
+
+class PrioritizedReplayBuffer(ReplayBuffer):
+    """
+    Prioritized Experience Replay buffer.
+    It stores and samples transitions like ``ReplayBuffer``; when selected, DQN trains with ``pal_loss`` instead of the Huber loss.
+    """
+
+    def __init__(
+        self,
+        buffer_size: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        device: Union[th.device, str] = "auto",
+        n_envs: int = 1,
+        optimize_memory_usage: bool = False,
+        handle_timeout_termination: bool = True,
+        alpha: float = 0.6,
+        beta: float = 0.4,
+    ):
+        super().__init__(
+            buffer_size, observation_space, action_space, device, n_envs, optimize_memory_usage, handle_timeout_termination
+        )
+        self.alpha = alpha
+        self.beta = beta
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index 5d1789969..c9811e52b 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -14,6 +14,7 @@
 import numpy as np
 import torch as th
 from gymnasium import spaces
+from torch import Tensor
 
 import stable_baselines3 as sb3
 
@@ -638,3 +639,45 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
     if print_info:
         print(env_info_str)
     return env_info, env_info_str
+
+
+def pal_loss(input_: Tensor, target_: Tensor, reduction: str = "mean", alpha: float = 0.6, beta: float = 0.4) -> Tensor:
+    """
+    Prioritized Approximation Loss (PAL).
+    Ref: "An Equivalence between Loss Functions and Non-Uniform Sampling in Experience Replay", Fujimoto et al. (NeurIPS 2020).
+
+    :param input_: (Tensor) Predicted values.
+    :param target_: (Tensor) Ground truth values.
+    :param reduction: (str) Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum' - default: 'mean'
+    :param alpha: (float) Priority exponent (as in Prioritized Experience Replay) - default: 0.6
+    :param beta: (float) Importance-sampling exponent (as in Prioritized Experience Replay) - default: 0.4
+
+    :return loss: (Tensor) Prioritized Approximation Loss.
+    """
+
+    if not (target_.size() == input_.size()):
+        warnings.warn(
+            f"Using a target size ({target_.size()}) that is different to the input size ({input_.size()}). "
+            "This will likely lead to incorrect results due to broadcasting. "
+            "Please ensure they have the same size.",
+            stacklevel=2,
+        )
+
+    # Per-sample exponent taken from the Huber loss: quadratic (2) for |delta| <= 1, linear (1) otherwise
+    delta = target_ - input_
+    abs_delta = delta.abs().detach()
+    tau = th.where(abs_delta <= 1.0, 2.0, 1.0)
+    power = tau + alpha * (1 - beta)
+
+    eta = abs_delta.pow(alpha * beta).min() / (abs_delta.pow(alpha).sum() + 1e-8)  # epsilon avoids division by zero
+    N = abs_delta.shape[0]
+    loss = (eta * N / power) * delta.abs().pow(power)  # non-detached |delta| so gradients flow back to input_
+
+    if reduction == "none":
+        return loss
+    elif reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return loss.sum()
+    else:
+        raise ValueError(f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'.")
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index f19f2a9d1..f6bc2af57 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -6,11 +6,11 @@
 from gymnasium import spaces
 from torch.nn import functional as F
 
-from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.buffers import PrioritizedReplayBuffer, ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update
+from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, pal_loss, polyak_update
 from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
 
 SelfDQN = TypeVar("SelfDQN", bound="DQN")
@@ -214,7 +214,11 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
             current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
 
             # Compute Huber loss (less sensitive to outliers)
-            loss = F.smooth_l1_loss(current_q_values, target_q_values)
+            loss = (
+                pal_loss(current_q_values, target_q_values, alpha=self.replay_buffer.alpha, beta=self.replay_buffer.beta)
+                if isinstance(self.replay_buffer, PrioritizedReplayBuffer)
+                else F.smooth_l1_loss(current_q_values, target_q_values)
+            )
             losses.append(loss.item())
 
             # Optimize the policy
diff --git a/test.py b/test.py
new file mode 100644
index 000000000..e69de29bb