Description
🐛 Bug
MaskableCategorical.apply_masking reinitializes Categorical and triggers validate_args on cached probs -> ValueError(Simplex) (torch 2.9)
Bug description
`sb3_contrib.common.maskable.distributions.MaskableCategorical.apply_masking()` calls `torch.distributions.Categorical.__init__` again on the same distribution instance:

```python
# sb3_contrib/common/maskable/distributions.py
super().__init__(logits=logits)
```

With `torch.distributions.Distribution._validate_args = True` (the default in my environment), this can raise:

```
ValueError: Expected parameter probs (...) to satisfy the constraint Simplex(), but found invalid values
```
This happens when:

1. `MaskableCategorical.probs` is accessed (common in deterministic action selection, `.mode()`, etc.), which caches the `probs` tensor on the instance
2. `apply_masking()` is later called again on the same instance (common: the action mask changes every step)
3. during the re-init, torch validates the cached `probs` against `constraints.Simplex()` and fails when the float32 softmax sum drifts from 1 by more than 1e-6 (easy with many categories, e.g. 992)
Note: a vanilla `torch.distributions.Categorical(logits=...)` does not validate Simplex on this path (logits are validated as real vectors), so it does not crash.
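This can be checked directly with a fresh construction, using the same 992-category logits as the reproduction below (a minimal sketch, assuming torch is installed):

```python
import torch as th
from torch.distributions import Categorical

# Same pathological logits as in the reproducer: 992 categories whose
# float32 softmax sum can drift away from 1 by more than 1e-6.
logits = th.full((1, 992), -17.0)
logits[0, 0] = 0.0

# A fresh Categorical validates only the logits (real vectors), never the
# derived probs, so construction succeeds even with validation enabled.
d = Categorical(logits=logits, validate_args=True)
drift = float((d.probs.sum(dim=-1) - 1).abs().max())
print("constructed OK, softmax sum drift:", drift)
```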
Minimal reproduction (no SB3 model / no env needed)
This reproducer only needs torch + sb3_contrib.
```python
import torch as th
from torch.distributions import Categorical
from sb3_contrib.common.maskable.distributions import MaskableCategorical

print("torch Distribution._validate_args:", th.distributions.Distribution._validate_args)

n = 992
# Construct logits that produce a softmax sum error > 1e-6 in float32
# (enough for torch.constraints.simplex.check() to fail in torch 2.9)
delta = 17
logits = th.full((1, n), -delta, dtype=th.float32)
logits[0, 0] = 0.0

# Sanity: plain torch Categorical is OK
p = Categorical(logits=logits).probs
print("Categorical probs sum:", float(p.sum()), "err:", float((p.sum(dim=-1) - 1).abs().max()))

# Create MaskableCategorical
mc = MaskableCategorical(logits=logits)

# Cache probs (mimics common usage)
_ = mc.probs
print("cached probs sum:", float(_.sum()))

# Re-apply masking (mimics SB3-contrib per-step masking)
mask = th.zeros((1, n), dtype=th.bool)
mask[0, 0] = True

# Expected: should update logits/mask and continue
# Observed: ValueError(Simplex) due to validation of cached probs during re-init
mc.apply_masking(mask)
```

Observed
Typical output (example):
```
Categorical probs sum: 1.0000020265 err: 2.026e-06
```

then:

```
ValueError: Expected parameter probs (Tensor of shape (1, 992)) ... to satisfy the constraint Simplex()
```
Expected
`apply_masking()` should not crash because of validation of stale cached `probs`. It should behave as if reinitializing from the logits alone.
Environment
- OS: Windows 11 (10.0.26200)
- Python: 3.13.11
- torch: 2.9.1+cu130
- sb3_contrib: 2.8.0a0
- stable_baselines3: 2.8.0a2
- gymnasium: 1.2.3
- numpy: 2.2.6
Root cause hypothesis
Torch caches `probs` on a distribution instance after `.probs` is first accessed (it is a `lazy_property`).
When `MaskableCategorical.apply_masking()` calls `Categorical.__init__` again on the same object, torch validates the existing instance fields and so ends up validating the cached `probs` against `Simplex()`, not just the new `logits`.
With many categories, the float32 softmax sum can deviate from exactly 1 by more than 1e-6, which trips the strict Simplex check.
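The caching step itself can be observed directly (a sketch relying on the torch-internal `lazy_property` behaviour):

```python
import torch as th
from torch.distributions import Categorical

d = Categorical(logits=th.zeros(1, 4), validate_args=True)
# probs is a lazy_property: nothing is cached until first access.
cached_before = "probs" in d.__dict__
_ = d.probs
cached_after = "probs" in d.__dict__
print(cached_before, cached_after)  # False True

# Distribution.__init__ validates every parameter it finds on the
# instance, and Categorical's constraint for probs is the simplex:
print(Categorical.arg_constraints["probs"])
```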
Proposed fix
Before reinitializing the `Categorical` base class inside `apply_masking()`, clear the cached probability fields:

```python
# Clear cached lazy_property values so torch validate_args does not validate stale probs
self.__dict__.pop("probs", None)
self.__dict__.pop("_probs", None)
super().__init__(logits=logits)
```

Alternative (less ideal): re-initialize with `validate_args=False`.
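The effect of the proposed fix can be sketched with plain torch, mimicking the re-init that `apply_masking()` performs (hypothetical standalone code, not the actual sb3_contrib patch):

```python
import torch as th
from torch.distributions import Categorical

logits = th.full((1, 992), -17.0)
logits[0, 0] = 0.0

d = Categorical(logits=logits, validate_args=True)
_ = d.probs  # cache probs on the instance, as deterministic prediction does

# Re-running __init__ on the same object (what apply_masking does) now
# validates the stale cached probs and may raise ValueError(Simplex),
# depending on how far the float32 softmax sum drifts on this machine.
try:
    Categorical.__init__(d, logits=logits, validate_args=True)
    reinit_raised = False
except ValueError:
    reinit_raised = True

# Proposed fix: drop the stale cache first, then re-init from logits only.
d.__dict__.pop("probs", None)
Categorical.__init__(d, logits=logits, validate_args=True)
print("re-init raised before fix:", reinit_raised, "- after clearing cache: OK")
```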
Additional context
In a real MaskablePPO evaluation (992 discrete actions + dynamic action masks), this shows up as a rare but reproducible crash during `model.predict(..., action_masks=mask, deterministic=True)`.
To Reproduce
Same code as the minimal reproduction above.
Relevant log output / Error message
System Info
No response
Checklist
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal and working example to reproduce the bug
- I've used the markdown code blocks for both code and stack traces.