1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ New Features:
- Added support for n-step returns for off-policy algorithms via the `n_steps` parameter
- Added ``NStepReplayBuffer`` that allows computing n-step returns without additional memory requirements (and without for loops)
- Added Gymnasium v1.2 support
- Added Prioritized Approximation Loss

Bug Fixes:
^^^^^^^^^^
Expand Down
3 changes: 2 additions & 1 deletion docs/modules/dqn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ Notes
- Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial

.. note::
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN or Dueling-DQN.
Prioritized Experience Replay can be used by passing ``PrioritizedReplayBuffer`` via the ``replay_buffer_class`` argument.
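
A minimal usage sketch (``alpha`` and ``beta`` shown here are the defaults from this PR):

.. code-block:: python

    from stable_baselines3 import DQN
    from stable_baselines3.common.buffers import PrioritizedReplayBuffer

    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        replay_buffer_class=PrioritizedReplayBuffer,
        replay_buffer_kwargs=dict(alpha=0.6, beta=0.4),
    )
    model.learn(total_timesteps=10_000)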


Can I use?
Expand Down
25 changes: 25 additions & 0 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,3 +952,28 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
rewards=self.to_torch(n_step_returns),
discounts=self.to_torch(target_q_discounts),
)


class PrioritizedReplayBuffer(ReplayBuffer):
"""
Prioritized Experience Replay Buffer.
This buffer is the same as the ReplayBuffer but when it is selected, pal_loss is used.
"""

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
alpha: float = 0.6,
beta: float = 0.4,
):
super().__init__(
buffer_size, observation_space, action_space, device, n_envs, optimize_memory_usage, handle_timeout_termination
)
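        # PER hyperparameters, read by ``DQN.train()`` and forwarded to ``pal_loss``:
        # alpha controls how strongly large TD errors are emphasized,
        # beta controls the importance-sampling correction.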
self.alpha = alpha
self.beta = beta
43 changes: 43 additions & 0 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch as th
from gymnasium import spaces
from torch import Tensor

import stable_baselines3 as sb3

Expand Down Expand Up @@ -638,3 +639,45 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
if print_info:
print(env_info_str)
return env_info, env_info_str


def pal_loss(input_: Tensor, target_: Tensor, reduction: str = "mean", alpha: float = 0.6, beta: float = 0.4) -> Tensor:
"""
Prioritized Approximation Loss
Ref: An Equivalence between Loss Functions and Non-Uniform Sampling in Experience Replay - Scott Fujimoto and al. (Neurips 2020)

:param input_: (Tensor) Predicted values.
:param target_: (Tensor) Ground truth values.
:param reduction: (str) Specifies the reduction to apply to the output 'none' | 'mean' | 'sum' - default: 'mean'
:param alpha: (float) - default=0.6
:param beta: (float) - default=0.4

:return loss: (Tensor) Prioritized Approximation Loss.
"""

if not (target_.size() == input_.size()):
warnings.warn(
f"Using a target size ({target_.size()}) that is different to the input size ({input_.size()}). "
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size.",
stacklevel=2,
)

    # Per-sample exponent: tau matches the Huber loss regions (quadratic for |delta| <= 1, linear otherwise)
    delta = target_ - input_
    abs_delta = delta.abs().detach()
    tau = th.where(abs_delta <= 1.0, th.full_like(abs_delta, 2.0), th.full_like(abs_delta, 1.0))
    power = tau + alpha * (1 - beta)

    # Normalization term, computed from detached TD errors so it does not carry gradients
    eta = abs_delta.pow(alpha * beta).min() / (abs_delta.pow(alpha).sum() + 1e-8)  # avoid zero division
    n_samples = abs_delta.shape[0]
    # Use the non-detached |delta| so that gradients flow through ``input_``
    loss = (eta * n_samples / power) * delta.abs().pow(power)

if reduction == "none":
return loss
elif reduction == "mean":
return loss.mean()
elif reduction == "sum":
        return loss.sum()
else:
raise ValueError(f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'.")
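
A quick sanity check of the function above (an illustrative sketch, not part of the PR): with the default ``'mean'`` reduction the call returns a scalar, and gradients flow only through the predictions.

    import torch as th
    from stable_baselines3.common.utils import pal_loss  # as added by this PR

    pred = th.zeros(4, 1, requires_grad=True)
    target = th.tensor([[0.5], [1.5], [-2.0], [0.1]])
    loss = pal_loss(pred, target)  # scalar with the default 'mean' reduction
    loss.backward()
    assert pred.grad is not None  # eta, tau and power are detached; gradients come from |delta|^power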
10 changes: 7 additions & 3 deletions stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.buffers import PrioritizedReplayBuffer, ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, polyak_update
from stable_baselines3.common.utils import LinearSchedule, get_parameters_by_name, pal_loss, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork

SelfDQN = TypeVar("SelfDQN", bound="DQN")
Expand Down Expand Up @@ -214,7 +214,11 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
                # (or the prioritized approximation loss when a PrioritizedReplayBuffer is used)
                if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                    loss = pal_loss(
                        current_q_values, target_q_values, alpha=self.replay_buffer.alpha, beta=self.replay_buffer.beta
                    )
                else:
                    loss = F.smooth_l1_loss(current_q_values, target_q_values)
losses.append(loss.item())

# Optimize the policy
Expand Down
Empty file added test.py
Empty file.