diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index d1e11c43..15f37052 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -17,6 +17,7 @@ New Features:
 Bug Fixes:
 ^^^^^^^^^^
 - Use the ``FloatSchedule`` and ``LinearSchedule`` classes instead of lambdas in the ARS, PPO, and QRDQN implementations to improve model portability across different operating systems
+- Fixed a bug in ``MaskableCategoricalDistribution`` and ``MaskableMultiCategoricalDistribution`` where ``apply_masking`` did not correctly handle masks for multi-dimensional action spaces
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py
index d2a92ae0..3e59ee3a 100644
--- a/sb3_contrib/common/maskable/distributions.py
+++ b/sb3_contrib/common/maskable/distributions.py
@@ -7,7 +7,7 @@
 from stable_baselines3.common.distributions import Distribution
 from torch import nn
 from torch.distributions import Categorical
-from torch.distributions.utils import logits_to_probs
+from torch.distributions.utils import probs_to_logits
 
 SelfMaskableCategoricalDistribution = TypeVar("SelfMaskableCategoricalDistribution", bound="MaskableCategoricalDistribution")
 SelfMaskableMultiCategoricalDistribution = TypeVar(
@@ -16,6 +16,26 @@
 MaybeMasks = Union[th.Tensor, np.ndarray, None]
 
 
+def _mask_logits(logits: th.Tensor, mask: MaybeMasks, neg_inf: float) -> th.Tensor:
+    """
+    Eliminate ("mask out") chosen categorical outcomes by setting their logits to ``neg_inf``.
+
+    :param logits: A tensor of unnormalized log probabilities (logits) for the categorical
+        distribution. The shape must be compatible with the mask.
+    :param mask: An optional boolean tensor or ndarray of shape compatible with the logits.
+        Where True, the corresponding logit is preserved. Where False, it is set to
+        ``neg_inf``, resulting in (near) zero probability. If mask is None, the logits
+        are returned unchanged.
+    :param neg_inf: The value to use for masked logits, typically negative infinity, so that
+        masked actions receive zero (or near-zero) probability under a softmax.
+    :return: The masked logits.
+    """
+    if mask is None:
+        return logits
+    mask_t = th.as_tensor(mask, dtype=th.bool, device=logits.device).reshape(logits.shape)
+    return th.where(mask_t, logits, th.tensor(neg_inf, dtype=logits.dtype, device=logits.device))
+
+
 class MaskableCategorical(Categorical):
     """
     Modified PyTorch Categorical distribution with support for invalid action masking.
@@ -39,52 +59,57 @@
     def __init__(
         self,
         probs: Optional[th.Tensor] = None,
         logits: Optional[th.Tensor] = None,
         validate_args: Optional[bool] = None,
         masks: MaybeMasks = None,
     ):
-        self.masks: Optional[th.Tensor] = None
-        super().__init__(probs, logits, validate_args)
-        self._original_logits = self.logits
-        self.apply_masking(masks)
+        # Validate that exactly one of probs or logits is provided
+        if (probs is None) == (logits is None):
+            raise ValueError("Specify exactly one of probs or logits but not both.")
+        # If probs was provided, convert it to logits
+        if logits is None:
+            logits = probs_to_logits(probs)
+        # Save the pristine logits for later re-masking (not detached, so gradients still flow)
+        self._original_logits = logits
+        self._neg_inf = float("-inf")
+        self.masks = None if masks is None else th.as_tensor(masks, dtype=th.bool, device=logits.device).reshape(logits.shape)
+        masked_logits = _mask_logits(logits, self.masks, self._neg_inf)
+        super().__init__(logits=masked_logits, validate_args=validate_args)
- """ + # If probs provided, convert it to logits + if logits is None: + logits = probs_to_logits(probs) - if masks is not None: - device = self.logits.device - self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape) - HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device) + # Save pristine logits for later masking + self._original_logits = logits.detach().clone() + self._neg_inf = float("-inf") + self.masks = None if masks is None else th.as_tensor(masks, dtype=th.bool, device=logits.device).reshape(logits.shape) + masked_logits = _mask_logits(logits, self.masks, self._neg_inf) + super().__init__(logits=masked_logits, validate_args=validate_args) - logits = th.where(self.masks, self._original_logits, HUGE_NEG) - else: + def apply_masking(self, masks: MaybeMasks) -> None: + if masks is None: self.masks = None logits = self._original_logits - + else: + self.masks = th.as_tensor(masks, dtype=th.bool, device=self._original_logits.device).reshape( + self._original_logits.shape + ) + logits = _mask_logits(self._original_logits, self.masks, self._neg_inf) # Reinitialize with updated logits super().__init__(logits=logits) - # self.probs may already be cached, so we must force an update - self.probs = logits_to_probs(self.logits) - def entropy(self) -> th.Tensor: if self.masks is None: return super().entropy() - # Highly negative logits don't result in 0 probs, so we must replace - # with 0s to ensure 0 contribution to the distribution's entropy, since - # masked actions possess no uncertainty. - device = self.logits.device - p_log_p = self.logits * self.probs - p_log_p = th.where(self.masks, p_log_p, th.tensor(0.0, device=device)) - return -p_log_p.sum(-1) + # Prevent numerical issues with masked logits + min_real = th.finfo(self.logits.dtype).min + logits = self.logits.clone() + mask = (~self.masks) | (~logits.isfinite()) + logits = logits.masked_fill(mask, min_real) + logits = logits - logits.logsumexp(-1, keepdim=True) + probs = logits.exp() + return -(logits * probs).sum(-1) class MaskableDistribution(Distribution, ABC):