From 502d7e0df77f82be4ffb73317d7ca3a2db37b03b Mon Sep 17 00:00:00 2001
From: bilelsag
Date: Tue, 5 Aug 2025 15:24:51 +0200
Subject: [PATCH 1/5] Add Prioritized Approximation Loss feature

---
 docs/modules/dqn.rst                |  3 +-
 stable_baselines3/common/buffers.py | 23 +++++++++++++++
 stable_baselines3/common/utils.py   | 43 +++++++++++++++++++++++++++++
 stable_baselines3/dqn/dqn.py        | 10 +++++--
 4 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst
index 78f70f6982..8e49c60443 100644
--- a/docs/modules/dqn.rst
+++ b/docs/modules/dqn.rst
@@ -28,7 +28,8 @@ Notes
 - Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial
 
 .. note::
-    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
+    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN or Dueling-DQN.
+    Prioritized Experience Replay can be used by passing ``PrioritizedReplayBuffer`` via the ``replay_buffer_class`` argument.
 
 
 Can I use?
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index af7308f5bf..dc61c50e26 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -948,3 +948,26 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
             rewards=self.to_torch(n_step_returns),
             discounts=self.to_torch(target_q_discounts),
         )
+
+
+class PrioritizedReplayBuffer(ReplayBuffer):
+    """
+    Prioritized Experience Replay Buffer.
+    The buffer is the same as ReplayBuffer but when this replay class is selected, pal_loss is used.
+    """
+
+    def __init__(
+        self,
+        buffer_size: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        device: Union[th.device, str] = "auto",
+        n_envs: int = 1,
+        optimize_memory_usage: bool = False,
+        handle_timeout_termination: bool = True,
+        alpha: float = 0.6,
+        beta: float = 0.4,
+    ):
+        super().__init__(
+            buffer_size, observation_space, action_space, device, n_envs, optimize_memory_usage, handle_timeout_termination
+        )
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index 9509326595..d2f2256ebd 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -14,6 +14,7 @@
 import numpy as np
 import torch as th
 from gymnasium import spaces
+from torch import Tensor
 
 import stable_baselines3 as sb3
 
@@ -638,3 +639,45 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
     if print_info:
         print(env_info_str)
     return env_info, env_info_str
+
+
+def pal_loss(input_: Tensor, target_: Tensor, reduction: str = "mean", alpha: float = 0.6, beta: float = 0.4) -> Tensor:
+    """
+    Prioritized Approximation Loss (PAL).
+    Ref: "An Equivalence between Loss Functions and Non-Uniform Sampling in Experience Replay",
+    Scott Fujimoto et al. (NeurIPS 2020).
+
+    :param input_: Predicted values.
+    :param target_: Ground truth values.
+    :param reduction: Reduction to apply to the output: 'none' | 'mean' | 'sum' (default: 'mean').
+    :param alpha: Priority exponent, default: 0.6.
+    :param beta: Importance sampling exponent, default: 0.4.
+    :return: Prioritized Approximation Loss.
+    """
+    if not (target_.size() == input_.size()):
+        warnings.warn(
+            f"Using a target size ({target_.size()}) that is different to the input size ({input_.size()}). "
+            "This will likely lead to incorrect results due to broadcasting. "
+            "Please ensure they have the same size.",
+            stacklevel=2,
+        )
+
+    delta = target_ - input_
+    abs_delta = delta.abs().detach()  # detached |TD error|, used only for the constant priority terms
+    # Huber-like exponent: quadratic branch (tau=2) when |delta| <= 1, linear branch (tau=1) otherwise
+    tau = th.where(abs_delta <= 1.0, th.tensor(2.0, device=delta.device), th.tensor(1.0, device=delta.device))
+    power = tau + alpha * (1 - beta)
+    # Normalization term, small constant to avoid division by zero
+    eta = abs_delta.pow(alpha * beta).min() / (abs_delta.pow(alpha).sum() + 1e-8)
+    N = abs_delta.shape[0]
+    # Use the non-detached delta so that gradients flow through input_ (do not wrap in th.tensor())
+    loss = (eta * N / power) * delta.abs().pow(power)
+
+    if reduction == "none":
+        return loss
+    elif reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return th.sum(loss)
+    else:
+        raise ValueError(f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'.")
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index f19f2a9d18..cddc8f5042 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -6,11 +6,11 @@
 from gymnasium import spaces
 from torch.nn import functional as F
 
-from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.buffers import PrioritizedReplayBuffer, ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update
+from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, pal_loss, polyak_update
 from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
 
 SelfDQN = TypeVar("SelfDQN", bound="DQN")
@@ -214,7 +214,11 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
                 current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
 
             # Compute Huber loss (less sensitive to outliers)
-            loss = F.smooth_l1_loss(current_q_values, target_q_values)
+            loss = (
+                F.smooth_l1_loss(current_q_values, target_q_values)
+                if not isinstance(self.replay_buffer, PrioritizedReplayBuffer)
+                else pal_loss(current_q_values, target_q_values)
+            )
             losses.append(loss.item())
 
             # Optimize the policy
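
For reference, a minimal sketch of how the new ``pal_loss`` helper can be exercised on its own,
assuming PATCH 1/5 above is applied. The shapes mirror what DQN passes in (a ``(batch_size, 1)``
column of Q-values); the tensors themselves are random placeholders:

    import torch as th

    from stable_baselines3.common.utils import pal_loss  # added in PATCH 1/5

    batch_size = 32
    current_q = th.randn(batch_size, 1, requires_grad=True)  # predicted Q-values
    target_q = th.randn(batch_size, 1)  # TD targets

    # PAL behaves like a Huber loss rescaled by detached priority terms (tau, power, eta)
    loss = pal_loss(current_q, target_q, reduction="mean", alpha=0.6, beta=0.4)
    loss.backward()  # gradients flow through current_q only
    print(loss.item(), current_q.grad.shape)
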
" + "Please ensure they have the same size.", + stacklevel=2, + ) + + # compute hyperparameters + delta = target_ - input_ + abs_delta = delta.abs().detach() + tau = th.where(abs_delta <= 1.0, th.tensor(2.0), th.tensor(1.0)).to(delta.device) # for huber loss + power = tau + alpha * (1 - beta) + + eta = abs_delta.pow(alpha * beta).min() / (abs_delta.pow(alpha).sum() + 1e-8) # avoid zero division + N = abs_delta.shape[0] + loss = th.tensor((eta * N / power) * (abs_delta.pow(power))) + + if reduction == "none": + return loss + elif reduction == "mean": + return loss.mean() + elif reduction == "sum": + return th.sum(loss) + else: + raise ValueError(f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'.") diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index f19f2a9d18..cddc8f5042 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -6,11 +6,11 @@ from gymnasium import spaces from torch.nn import functional as F -from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.buffers import ReplayBuffer, PrioritizedReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update +from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update, pal_loss from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork SelfDQN = TypeVar("SelfDQN", bound="DQN") @@ -214,7 +214,11 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long()) # Compute Huber loss (less sensitive to outliers) - loss = F.smooth_l1_loss(current_q_values, target_q_values) + loss = ( + F.smooth_l1_loss(current_q_values, target_q_values) + if not isinstance(self.replay_buffer_class, PrioritizedReplayBuffer) + else pal_loss(current_q_values, target_q_values) + ) losses.append(loss.item()) # Optimize the policy From 53a7b2d99f2cf4d86ffdacf2461d9fde954f3147 Mon Sep 17 00:00:00 2001 From: bilelsag Date: Tue, 5 Aug 2025 15:58:43 +0200 Subject: [PATCH 2/5] Add changelog --- docs/misc/changelog.rst | 1 + test.py | 0 2 files changed, 1 insertion(+) create mode 100644 test.py diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9f4cb7dc05..16378067cd 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,7 @@ New Features: - Added support for n-step returns for off-policy algorithms via the `n_steps` parameter - Added ``NStepReplayBuffer`` that allows to compute n-step returns without additional memory requirement (and without for loops) - Added Gymnasium v1.2 support +- Added Prioritized Approximation Loss Bug Fixes: ^^^^^^^^^^ diff --git a/test.py b/test.py new file mode 100644 index 0000000000..e69de29bb2 From 2fa2d87970fc11aac72863148cfd26dce0f073af Mon Sep 17 00:00:00 2001 From: Bilel <61874108+bilelsgh@users.noreply.github.com> Date: Tue, 5 Aug 2025 16:53:47 +0200 Subject: [PATCH 3/5] Update buffers.py Doctring imrovement --- stable_baselines3/common/buffers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index dc61c50e26..32d3be9fbe 100644 --- 
From dc99748fbf87eab594f186b1fded065c052fa30e Mon Sep 17 00:00:00 2001
From: Bilel <61874108+bilelsgh@users.noreply.github.com>
Date: Tue, 5 Aug 2025 16:55:42 +0200
Subject: [PATCH 4/5] Add alpha and beta parameters

---
 stable_baselines3/dqn/dqn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index cddc8f5042..f6bc2af573 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -217,7 +217,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
             loss = (
                 F.smooth_l1_loss(current_q_values, target_q_values)
                 if not isinstance(self.replay_buffer, PrioritizedReplayBuffer)
-                else pal_loss(current_q_values, target_q_values)
+                else pal_loss(current_q_values, target_q_values, alpha=self.replay_buffer.alpha, beta=self.replay_buffer.beta)
             )
             losses.append(loss.item())
 

From d7c1bfaa4d126ea86ed6a56cbeda16ef615ee71f Mon Sep 17 00:00:00 2001
From: Bilel <61874108+bilelsgh@users.noreply.github.com>
Date: Tue, 5 Aug 2025 16:59:30 +0200
Subject: [PATCH 5/5] Add alpha and beta attributes to PrioritizedReplayBuffer

---
 stable_baselines3/common/buffers.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 18c8df0d38..e909fb0033 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -975,3 +975,5 @@ def __init__(
         super().__init__(
             buffer_size, observation_space, action_space, device, n_envs, optimize_memory_usage, handle_timeout_termination
         )
+        self.alpha = alpha
+        self.beta = beta
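
Taken together, the series is meant to be used as follows. A minimal usage sketch, assuming all
five patches are applied: ``replay_buffer_class`` and ``replay_buffer_kwargs`` are existing
``DQN``/``OffPolicyAlgorithm`` arguments, and ``CartPole-v1`` is only a placeholder environment.

    from stable_baselines3 import DQN
    from stable_baselines3.common.buffers import PrioritizedReplayBuffer  # added in PATCH 1/5

    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        replay_buffer_class=PrioritizedReplayBuffer,
        # alpha and beta are stored on the buffer (PATCH 5/5) and forwarded to pal_loss (PATCH 4/5)
        replay_buffer_kwargs=dict(alpha=0.6, beta=0.4),
        verbose=1,
    )
    model.learn(total_timesteps=10_000)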