
Commit 0c9b793

Removed entropy term (#1)
1 parent 23cb60b commit 0c9b793

File tree

1 file changed: +2, -16 lines changed

stable_baselines3/rsppo/rsppo.py

Lines changed: 2 additions & 16 deletions
@@ -46,7 +46,6 @@ class RSPPO(OnPolicyAlgorithm):
         no clipping will be done on the value function.
         IMPORTANT: this clipping depends on the reward scaling.
     :param normalize_advantage: Whether to normalize or not the advantage
-    :param ent_coef: Entropy coefficient for the loss calculation
     :param vf_coef: Value function coefficient for the loss calculation
     :param max_grad_norm: The maximum value for the gradient clipping
     :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
@@ -90,7 +89,6 @@ def __init__(
         clip_range: Union[float, Schedule] = 0.2,
         clip_range_vf: Union[None, float, Schedule] = None,
         normalize_advantage: bool = True,
-        ent_coef: float = 0.0,
         vf_coef: float = 0.5,
         max_grad_norm: float = 0.5,
         use_sde: bool = False,
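
For context, a minimal usage sketch of the trimmed constructor. It assumes the fork keeps SB3's usual PPO-style API (policy string, env, `learn()`), that `RSPPO` is importable from the module shown above, and that a Gymnasium environment is used; the environment id and hyperparameter values are placeholders, not anything taken from this commit:

```python
import gymnasium as gym

# Import path inferred from the file touched here (stable_baselines3/rsppo/rsppo.py);
# adjust if the fork re-exports RSPPO elsewhere.
from stable_baselines3.rsppo.rsppo import RSPPO

env = gym.make("CartPole-v1")  # placeholder environment

# Note: there is no ent_coef argument anymore; the loss now has only the
# clipped policy term plus vf_coef * value_loss.
model = RSPPO(
    "MlpPolicy",
    env,
    clip_range=0.2,
    clip_range_vf=None,
    normalize_advantage=True,
    vf_coef=0.5,
    max_grad_norm=0.5,
)
model.learn(total_timesteps=10_000)
```
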
@@ -113,7 +111,6 @@ def __init__(
             n_steps=n_steps,
             gamma=gamma,
             gae_lambda=gae_lambda,
-            ent_coef=ent_coef,
             vf_coef=vf_coef,
             max_grad_norm=max_grad_norm,
             use_sde=use_sde,
@@ -195,7 +192,6 @@ def train(self) -> None:
         if self.clip_range_vf is not None:
             clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]
 
-        entropy_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []
 
@@ -210,7 +206,7 @@ def train(self) -> None:
                     # Convert discrete action from float to long
                     actions = rollout_data.actions.long().flatten()
 
-                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+                values, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions)
                 values = values.flatten()
                 # Normalize advantage
                 advantages = rollout_data.advantages
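
The third value returned by `evaluate_actions` (the entropy) is now simply discarded; `log_prob` still feeds the clipped surrogate objective a few lines further down, as in upstream SB3 PPO. As a reminder of what that unchanged step computes, here is a self-contained sketch of the clipped surrogate loss; the function name and the standalone-tensor interface are illustrative, not code from this file:

```python
import torch as th

def clipped_surrogate_loss(
    advantages: th.Tensor,
    log_prob: th.Tensor,
    old_log_prob: th.Tensor,
    clip_range: float = 0.2,
) -> th.Tensor:
    """PPO clipped surrogate objective, with no entropy bonus attached."""
    # Normalize advantages (done in train() when normalize_advantage=True).
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    # Probability ratio between the new and the old policy.
    ratio = th.exp(log_prob - old_log_prob)
    # Pessimistic minimum of the unclipped and clipped terms, negated for gradient descent.
    return -th.min(
        advantages * ratio,
        advantages * th.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range),
    ).mean()
```
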
@@ -246,16 +242,7 @@ def train(self) -> None:
                 value_loss = F.mse_loss(rollout_data.returns, values_pred)
                 value_losses.append(value_loss.item())
 
-                # Entropy loss favor exploration
-                if entropy is None:
-                    # Approximate entropy when no analytical form
-                    entropy_loss = -th.mean(-log_prob)
-                else:
-                    entropy_loss = -th.mean(entropy)
-
-                entropy_losses.append(entropy_loss.item())
-
-                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
+                loss = policy_loss + self.vf_coef * value_loss
 
                 # Calculate approximate form of reverse KL Divergence for early stopping
                 # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
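
With the entropy bonus gone, the total loss is just the clipped policy loss plus the weighted value loss; the optimizer step that follows in `train()` is untouched by this commit. A rough, self-contained sketch of that step under SB3-style conventions; the function and parameter names here are stand-ins, not code from this file:

```python
import torch as th
from torch.nn.utils import clip_grad_norm_

def optimization_step(
    parameters,                      # iterable of policy parameters
    optimizer: th.optim.Optimizer,
    policy_loss: th.Tensor,
    value_loss: th.Tensor,
    vf_coef: float = 0.5,
    max_grad_norm: float = 0.5,
) -> th.Tensor:
    """One gradient step on the two-term loss (no entropy term)."""
    # The combined objective after this commit: no ent_coef * entropy_loss term.
    loss = policy_loss + vf_coef * value_loss

    optimizer.zero_grad()
    loss.backward()
    # Clip the gradient norm before stepping, as SB3's train() does.
    clip_grad_norm_(parameters, max_grad_norm)
    optimizer.step()
    return loss
```
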
@@ -286,7 +273,6 @@ def train(self) -> None:
         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
 
         # Logs
-        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
         self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
         self.logger.record("train/value_loss", np.mean(value_losses))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
