[Bug]: MaskableCategorical.apply_masking triggers Simplex validate_args error with torch 2.9 #322

Description

🐛 Bug

MaskableCategorical.apply_masking reinitializes Categorical and triggers validate_args on cached probs -> ValueError(Simplex) (torch 2.9)

Bug description

sb3_contrib.common.maskable.distributions.MaskableCategorical.apply_masking() calls torch.distributions.Categorical.__init__ again on the same distribution instance:

# sb3_contrib/common/maskable/distributions.py
super().__init__(logits=logits)

With torch.distributions.Distribution._validate_args = True (default in my env), this can raise:

ValueError: Expected parameter probs (...) to satisfy the constraint Simplex(), but found invalid values

This happens when:

  1. MaskableCategorical.probs is accessed (common during deterministic action selection, e.g. via .mode()), which caches the probs tensor on the instance
  2. apply_masking() is later called again on the same instance (common: the action mask changes every step)
  3. during the re-init, torch validates the cached probs against constraints.Simplex() and fails when the float32 softmax sum deviates from 1 by more than 1e-6 (likelier with many categories, e.g. 992)

Note: vanilla torch.distributions.Categorical(logits=...) does not validate Simplex in this path (logits are validated as real vectors), so it does not crash.
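
For reference, torch's simplex constraint uses a hard-coded 1e-6 tolerance, so the failure mode can be shown with plain torch (a standalone check; the exact drift magnitude may vary across platforms and BLAS builds):

import torch as th
from torch.distributions import constraints

n = 992
logits = th.full((1, n), -17.0)
logits[0, 0] = 0.0
probs = th.softmax(logits, dim=-1)

# constraints.simplex.check() requires |sum - 1| < 1e-6 along the last dim
print(float((probs.sum(dim=-1) - 1).abs().max()))  # ~2e-06 in the run above; may vary
print(constraints.simplex.check(probs))            # tensor([False]) once drift > 1e-6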

Minimal reproduction (no SB3 model / no env needed)

This reproducer only needs torch + sb3_contrib.

import torch as th
from torch.distributions import Categorical
from sb3_contrib.common.maskable.distributions import MaskableCategorical

print("torch Distribution._validate_args:", th.distributions.Distribution._validate_args)

n = 992
# Construct logits that produce a softmax sum error > 1e-6 in float32
# (enough for torch.constraints.simplex.check() to fail in torch 2.9)
delta = 17
logits = th.full((1, n), -delta, dtype=th.float32)
logits[0, 0] = 0.0

# Sanity: plain torch Categorical is OK
p = Categorical(logits=logits).probs
print("Categorical probs sum:", float(p.sum()), "err:", float((p.sum(dim=-1) - 1).abs().max()))

# Create MaskableCategorical
mc = MaskableCategorical(logits=logits)

# Cache probs (mimics common usage, e.g. deterministic action selection)
cached_probs = mc.probs
print("cached probs sum:", float(cached_probs.sum()))

# Re-apply masking (mimics SB3-contrib per-step masking)
mask = th.zeros((1, n), dtype=th.bool)
mask[0, 0] = True

# Expected: should update logits/mask and continue
# Observed: ValueError(Simplex) due to validation of cached probs during re-init
mc.apply_masking(mask)

Observed

Typical output (example):

Categorical probs sum: 1.0000020265 err: 2.026e-06

followed by:

ValueError: Expected parameter probs (Tensor of shape (1, 992)) ... to satisfy the constraint Simplex()

Expected

apply_masking() should not crash on validation of stale cached probs. It should behave as if the distribution were re-initialized from the new logits alone.

Environment

  • OS: Windows 11 (10.0.26200)
  • Python: 3.13.11
  • torch: 2.9.1+cu130
  • sb3_contrib: 2.8.0a0
  • stable_baselines3: 2.8.0a2
  • gymnasium: 1.2.3
  • numpy: 2.2.6

Root cause hypothesis

Torch's lazy_property caches probs in the instance's __dict__ the first time .probs is accessed.
When MaskableCategorical.apply_masking() calls Categorical.__init__ again on the same object, Distribution.__init__ skips validation of a lazy parameter only when it is absent from __dict__, so it ends up validating the stale cached probs against Simplex(), not just the new logits.

With many categories, the float32 softmax sum can deviate from exactly 1 by more than 1e-6, which trips the strict Simplex check.
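
The caching itself is easy to observe on a plain Categorical (no sb3_contrib involved):

import torch as th
from torch.distributions import Categorical

d = Categorical(logits=th.zeros(1, 4))
print("probs" in d.__dict__)  # False: the lazy_property has not run yet
_ = d.probs                   # first access computes and caches the tensor
print("probs" in d.__dict__)  # True: any later Distribution.__init__ on this
                              # instance will now validate the cached value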

Proposed fix

Before reinitializing the Categorical base class inside apply_masking(), clear cached probability fields:

# Clear cached lazy_property values so torch validate_args does not validate stale probs
self.__dict__.pop("probs", None)
self.__dict__.pop("_probs", None)

super().__init__(logits=logits)
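
Until a patch lands, the same idea also works as a thin user-side wrapper (a sketch; PatchedMaskableCategorical is a hypothetical name, not part of sb3_contrib):

from sb3_contrib.common.maskable.distributions import MaskableCategorical

class PatchedMaskableCategorical(MaskableCategorical):
    def apply_masking(self, masks):
        # Drop the lazily cached probs before the parent re-initializes the
        # Categorical base class, so validate_args only sees the new logits.
        self.__dict__.pop("probs", None)
        super().apply_masking(masks)

Wiring this into MaskablePPO's distribution construction is out of scope here; the snippet only demonstrates the cache-clearing idea.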

Alternative (less ideal): re-init with validate_args=False.
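
That alternative can be approximated today without touching sb3_contrib, at the cost of weakening validation globally:

import torch as th

# Stopgap only: silences torch.distributions argument checks for every
# distribution that does not set validate_args explicitly, including
# legitimate checks that would catch real bugs.
th.distributions.Distribution.set_default_validate_args(False)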

Additional context

In a real MaskablePPO evaluation (992-discrete actions + dynamic action masks), this shows up as a rare but reproducible crash during model.predict(..., action_masks=mask, deterministic=True).

