Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Use the ``FloatSchedule`` and ``LinearSchedule`` classes instead of lambdas in the ARS, PPO, and QRDQN implementations to improve model portability across different operating systems
- Fixed a bug in the ``MaskableCategoricalDistribution`` and ``MaskableMultiCategoricalDistribution`` classes where the ``apply_masking`` method did not correctly handle the masks for multi-dimensional action spaces

Deprecations:
^^^^^^^^^^^^^
Expand Down
82 changes: 51 additions & 31 deletions sb3_contrib/common/maskable/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from stable_baselines3.common.distributions import Distribution
from torch import nn
from torch.distributions import Categorical
from torch.distributions.utils import logits_to_probs
from torch.distributions.utils import probs_to_logits

SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution")
SelfMaskableMultiCategoricalDistribution = TypeVar(
Expand All @@ -16,6 +16,29 @@
MaybeMasks = Union[th.Tensor, np.ndarray, None]


def _mask_logits(logits: th.Tensor, mask: MaybeMasks, neg_inf: float) -> th.Tensor:
"""
Eliminate chosen categorical outcomes by setting their logits to `neg_inf`.

:param logits: A tensor of unnormalized log probabilities (logits) for the categorical distribution.
The shape should be compatible with the mask.

:param mask: An optional boolean ndarray of compatible shape with the distribution.
If True, the corresponding choice's logit value is preserved. If False, it is set
to a large negative value, resulting in near 0 probability. If mask is None, any
previously applied masking is removed, and the original logits are restored.

:param neg_inf: The value to use for masked logits, typically negative infinity
to ensure the masked actions have zero (or near-zero) probability when passed
through a softmax or categorical distribution.
"""

if mask is None:
return logits
mask_t = th.as_tensor(mask, dtype=th.bool, device=logits.device).reshape(logits.shape)
return th.where(mask_t, logits, th.tensor(neg_inf, dtype=logits.dtype, device=logits.device))


class MaskableCategorical(Categorical):
"""
Modified PyTorch Categorical distribution with support for invalid action masking.
Expand All @@ -39,48 +62,45 @@ def __init__(
validate_args: Optional[bool] = None,
masks: MaybeMasks = None,
):
self.masks: Optional[th.Tensor] = None
super().__init__(probs, logits, validate_args)
self._original_logits = self.logits
self.apply_masking(masks)

def apply_masking(self, masks: MaybeMasks) -> None:
"""
Eliminate ("mask out") chosen categorical outcomes by setting their probability to 0.
# Validate that exactly one of probs or logits is provided
if (probs is None) == (logits is None):
raise ValueError("Specify exactly one of probs or logits but not both.")

:param masks: An optional boolean ndarray of compatible shape with the distribution.
If True, the corresponding choice's logit value is preserved. If False, it is set
to a large negative value, resulting in near 0 probability. If masks is None, any
previously applied masking is removed, and the original logits are restored.
"""
# If probs provided, convert it to logits
if logits is None:
logits = probs_to_logits(probs)

if masks is not None:
device = self.logits.device
self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape)
HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device)
# Save pristine logits for later masking
self._original_logits = logits.detach().clone()
self._neg_inf = float("-inf")
self.masks = None if masks is None else th.as_tensor(masks, dtype=th.bool, device=logits.device).reshape(logits.shape)
masked_logits = _mask_logits(logits, self.masks, self._neg_inf)
super().__init__(logits=masked_logits, validate_args=validate_args)

logits = th.where(self.masks, self._original_logits, HUGE_NEG)
else:
def apply_masking(self, masks: MaybeMasks) -> None:
if masks is None:
self.masks = None
logits = self._original_logits

else:
self.masks = th.as_tensor(masks, dtype=th.bool, device=self._original_logits.device).reshape(
self._original_logits.shape
)
logits = _mask_logits(self._original_logits, self.masks, self._neg_inf)
# Reinitialize with updated logits
super().__init__(logits=logits)

# self.probs may already be cached, so we must force an update
self.probs = logits_to_probs(self.logits)

def entropy(self) -> th.Tensor:
if self.masks is None:
return super().entropy()

# Highly negative logits don't result in 0 probs, so we must replace
# with 0s to ensure 0 contribution to the distribution's entropy, since
# masked actions possess no uncertainty.
device = self.logits.device
p_log_p = self.logits * self.probs
p_log_p = th.where(self.masks, p_log_p, th.tensor(0.0, device=device))
return -p_log_p.sum(-1)
# Prevent numerical issues with masked logits
min_real = th.finfo(self.logits.dtype).min
logits = self.logits.clone()
mask = (~self.masks) | (~logits.isfinite())
logits = logits.masked_fill(mask, min_real)
logits = logits - logits.logsumexp(-1, keepdim=True)
probs = logits.exp()
return -(logits * probs).sum(-1)


class MaskableDistribution(Distribution, ABC):
Expand Down