diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index b9bffb8c..52c862d1 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -19,6 +19,7 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
+- Do not call ``forward()`` method directly in ``RecurrentPPO``
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/sb3_contrib/common/maskable/distributions.py b/sb3_contrib/common/maskable/distributions.py
index 52d1d709..b1bf92e2 100644
--- a/sb3_contrib/common/maskable/distributions.py
+++ b/sb3_contrib/common/maskable/distributions.py
@@ -110,9 +110,10 @@ class MaskableCategoricalDistribution(MaskableDistribution):
     :param action_dim: Number of discrete actions
     """
 
+    distribution: MaskableCategorical
+
     def __init__(self, action_dim: int):
         super().__init__()
-        self.distribution: MaskableCategorical | None = None
         self.action_dim = action_dim
 
     def proba_distribution_net(self, latent_dim: int) -> nn.Module:
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index bdc3b85c..0c32e97a 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -239,7 +239,7 @@ def collect_rollouts(
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 episode_starts = th.tensor(self._last_episode_starts, dtype=th.float32, device=self.device)
-                actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
+                actions, values, log_probs, lstm_states = self.policy(obs_tensor, lstm_states, episode_starts)
 
             actions = actions.cpu().numpy()
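
Context for the ``distributions.py`` hunk: moving ``distribution`` from an ``Optional`` instance attribute initialized to ``None`` up to a bare class-level annotation lets type checkers treat it as always being a ``MaskableCategorical`` once assigned. A minimal sketch of the pattern under that assumption (the ``OptionalStyle``/``AnnotationStyle`` class names are hypothetical, not part of the patch):

```python
from __future__ import annotations

import torch as th

# Hypothetical before/after illustration of the typing pattern in the
# distributions.py hunk: a bare class-level annotation replaces an
# Optional instance attribute.

class OptionalStyle:
    def __init__(self) -> None:
        # Checkers infer ``Categorical | None``: every caller must
        # narrow away the None case before using the attribute.
        self.distribution: th.distributions.Categorical | None = None

class AnnotationStyle:
    # Declared but never assigned here: accessing it before it is set
    # raises AttributeError, and checkers see a non-Optional type.
    distribution: th.distributions.Categorical

    def proba_distribution(self, logits: th.Tensor) -> AnnotationStyle:
        self.distribution = th.distributions.Categorical(logits=logits)
        return self
```

This fits how the SB3 distribution classes are used: ``proba_distribution()`` always assigns the attribute before any sampling or log-prob call, so the ``None`` state was never a legitimate value to begin with.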
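
Context for the ``ppo_recurrent.py`` hunk and the matching changelog entry: in PyTorch, ``policy(obs, ...)`` goes through ``nn.Module.__call__``, which runs any registered hooks before dispatching to ``forward()``, whereas calling ``.forward()`` directly silently skips them. A minimal sketch of the difference (``TinyPolicy`` is a hypothetical stand-in, not the actual ``RecurrentPPO`` policy):

```python
import torch as th
import torch.nn as nn

# Hypothetical module demonstrating why ``module(x)`` is preferred
# over ``module.forward(x)``: only ``__call__`` runs registered hooks.

class TinyPolicy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        return self.linear(obs)

policy = TinyPolicy()
calls: list[str] = []
# Forward hooks are dispatched by nn.Module.__call__, not by forward().
policy.register_forward_hook(lambda module, args, output: calls.append("hook"))

obs = th.randn(1, 4)
policy(obs)          # __call__ path: the hook fires
policy.forward(obs)  # direct call: the hook is silently skipped
assert calls == ["hook"]
```

Hooks are how features such as activation logging and some profiling tools attach themselves to a module, so bypassing ``__call__`` in ``collect_rollouts`` could make those integrations miss every rollout-time policy invocation.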