Commit eeedda8

Removed entropy term (#1)
1 parent 23cb60b commit eeedda8

2 files changed (+5, -18 lines)

example.py

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
 import numpy as np
 import gymnasium as gym
-from stable_baselines3 import PPO
+from stable_baselines3 import PPO, RSPPO
 from stable_baselines3.common.utils import set_random_seed

 set_random_seed(42)

 env = gym.make('CartPole-v1')
-model = PPO('MlpPolicy', env, verbose=1)
+model = RSPPO('MlpPolicy', env, verbose=1)

 model.learn(total_timesteps=1e6)

stable_baselines3/rsppo/rsppo.py

Lines changed: 3 additions & 16 deletions
@@ -46,7 +46,6 @@ class RSPPO(OnPolicyAlgorithm):
         no clipping will be done on the value function.
         IMPORTANT: this clipping depends on the reward scaling.
     :param normalize_advantage: Whether to normalize or not the advantage
-    :param ent_coef: Entropy coefficient for the loss calculation
     :param vf_coef: Value function coefficient for the loss calculation
     :param max_grad_norm: The maximum value for the gradient clipping
     :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
@@ -90,7 +89,6 @@ def __init__(
         clip_range: Union[float, Schedule] = 0.2,
         clip_range_vf: Union[None, float, Schedule] = None,
         normalize_advantage: bool = True,
-        ent_coef: float = 0.0,
         vf_coef: float = 0.5,
         max_grad_norm: float = 0.5,
         use_sde: bool = False,
@@ -113,7 +111,6 @@ def __init__(
             n_steps=n_steps,
             gamma=gamma,
             gae_lambda=gae_lambda,
-            ent_coef=ent_coef,
             vf_coef=vf_coef,
             max_grad_norm=max_grad_norm,
             use_sde=use_sde,
@@ -126,6 +123,7 @@ def __init__(
             verbose=verbose,
             device=device,
             seed=seed,
+            ent_coef=0,
             _init_setup_model=False,
             supported_action_spaces=(
                 spaces.Box,
@@ -195,7 +193,6 @@ def train(self) -> None:
         if self.clip_range_vf is not None:
             clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

-        entropy_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []

@@ -210,7 +207,7 @@ def train(self) -> None:
                     # Convert discrete action from float to long
                     actions = rollout_data.actions.long().flatten()

-                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+                values, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions)
                 values = values.flatten()
                 # Normalize advantage
                 advantages = rollout_data.advantages
@@ -246,16 +243,7 @@ def train(self) -> None:
                 value_loss = F.mse_loss(rollout_data.returns, values_pred)
                 value_losses.append(value_loss.item())

-                # Entropy loss favor exploration
-                if entropy is None:
-                    # Approximate entropy when no analytical form
-                    entropy_loss = -th.mean(-log_prob)
-                else:
-                    entropy_loss = -th.mean(entropy)
-
-                entropy_losses.append(entropy_loss.item())
-
-                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
+                loss = policy_loss + self.vf_coef * value_loss

                 # Calculate approximate form of reverse KL Divergence for early stopping
                 # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
@@ -286,7 +274,6 @@ def train(self) -> None:
         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

         # Logs
-        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
         self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
         self.logger.record("train/value_loss", np.mean(value_losses))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
