diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py
index 37b9e7f5..40f9368b 100644
--- a/sb3_contrib/common/maskable/policies.py
+++ b/sb3_contrib/common/maskable/policies.py
@@ -280,7 +280,12 @@ def evaluate_actions(
         distribution.apply_masking(action_masks)
         log_prob = distribution.log_prob(actions)
         values = self.value_net(latent_vf)
-        return values, log_prob, distribution.entropy()
+        try:
+            entropy = distribution.entropy()
+        except NotImplementedError:
+            # When no analytical form, entropy needs to be estimated using -log_prob.mean()
+            entropy = -log_prob.mean(dim=1)
+        return values, log_prob, entropy

     def get_distribution(self, obs: th.Tensor, action_masks: Optional[np.ndarray] = None) -> MaskableDistribution:
         """
diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py
index cbc27181..34a29f5e 100644
--- a/sb3_contrib/common/recurrent/policies.py
+++ b/sb3_contrib/common/recurrent/policies.py
@@ -329,7 +329,12 @@ def evaluate_actions(
         distribution = self._get_action_dist_from_latent(latent_pi)
         log_prob = distribution.log_prob(actions)
         values = self.value_net(latent_vf)
-        return values, log_prob, distribution.entropy()
+        try:
+            entropy = distribution.entropy()
+        except NotImplementedError:
+            # When no analytical form, entropy needs to be estimated using -log_prob.mean()
+            entropy = -log_prob.mean(dim=1)
+        return values, log_prob, entropy

     def _predict(
         self,
diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py
index 78aa58f0..0a3c6cca 100644
--- a/sb3_contrib/ppo_mask/ppo_mask.py
+++ b/sb3_contrib/ppo_mask/ppo_mask.py
@@ -446,11 +446,7 @@ def train(self) -> None:
                 value_losses.append(value_loss.item())

                 # Entropy loss favor exploration
-                if entropy is None:
-                    # Approximate entropy when no analytical form
-                    entropy_loss = -th.mean(-log_prob)
-                else:
-                    entropy_loss = -th.mean(entropy)
+                entropy_loss = -th.mean(entropy)

                 entropy_losses.append(entropy_loss.item())

diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index 12a8c7f5..232a6d25 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -389,11 +389,7 @@ def train(self) -> None:
                 value_losses.append(value_loss.item())

                 # Entropy loss favor exploration
-                if entropy is None:
-                    # Approximate entropy when no analytical form
-                    entropy_loss = -th.mean(-log_prob[mask])
-                else:
-                    entropy_loss = -th.mean(entropy[mask])
+                entropy_loss = -th.mean(entropy[mask])

                 entropy_losses.append(entropy_loss.item())
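
For reference, the fallback pattern introduced in evaluate_actions() above can be exercised in isolation. The sketch below is illustrative only: NoAnalyticalEntropy is a hypothetical stand-in for a distribution that implements log_prob() but has no closed-form entropy(); it is not part of sb3_contrib or PyTorch, and the scalar -log_prob.mean() estimate is used here purely to show the try/except flow.

import torch as th

# Hypothetical stand-in for a distribution whose entropy has no analytical form;
# only log_prob() is available, as with some custom or masked distributions.
class NoAnalyticalEntropy:
    def __init__(self, logits: th.Tensor):
        self.log_probs = th.log_softmax(logits, dim=-1)

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        return self.log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    def entropy(self) -> th.Tensor:
        raise NotImplementedError


dist = NoAnalyticalEntropy(th.randn(4, 3))
actions = th.randint(0, 3, (4,))
log_prob = dist.log_prob(actions)

try:
    entropy = dist.entropy()
except NotImplementedError:
    # No closed form: estimate the entropy as -E[log p(a)] over the batch,
    # mirroring the fallback added in evaluate_actions() above.
    entropy = -log_prob.mean()

entropy_loss = -th.mean(entropy)  # same sign convention as in train()

With this fallback living in evaluate_actions(), train() can always rely on a tensor-valued entropy, which is why the `if entropy is None` branches are removed from ppo_mask and ppo_recurrent in the diff.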