Commit 411ff69

adamfrly and araffin authored
Ensure train/n_updates metric accounts for early stopping of training loop (#1311)
* Correct _n_updates when target_kl stops loop early
* Update changelog
* Simplify code

Co-authored-by: Antonin Raffin <[email protected]>
1 parent d0c1a87 commit 411ff69

File tree

2 files changed: +2 -2 lines changed


docs/misc/changelog.rst

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ New Features:
 Bug Fixes:
 ^^^^^^^^^^
 - Fixed Atari wrapper that missed the reset condition (@luizapozzobon)
+- Fixed PPO train/n_updates metric not accounting for early stopping (@adamfrly)

 Deprecations:
 ^^^^^^^^^^^^^

stable_baselines3/ppo/ppo.py

Lines changed: 1 addition & 2 deletions
@@ -189,7 +189,6 @@ def train(self) -> None:
         clip_fractions = []

         continue_training = True
-
         # train for n_epochs epochs
         for epoch in range(self.n_epochs):
             approx_kl_divs = []
@@ -271,10 +270,10 @@ def train(self) -> None:
                 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()

+            self._n_updates += 1
             if not continue_training:
                 break

-        self._n_updates += self.n_epochs
         explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

         # Logs
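
For context, here is a minimal, self-contained sketch of why the counter moved inside the epoch loop. The names (n_epochs, stop_after_epoch, n_updates_old, n_updates_new) are hypothetical and this is not the actual SB3 train() method; it only illustrates the two counting strategies when target_kl stops training early.

# Minimal sketch (hypothetical names, not the actual SB3 train() method) of
# the two counting strategies when an early stop is triggered.

n_epochs = 10
stop_after_epoch = 3  # pretend approx_kl exceeded 1.5 * target_kl in this epoch

# Old counting: n_epochs is added after the loop, regardless of early stopping.
n_updates_old = 0
for epoch in range(n_epochs):
    if epoch == stop_after_epoch:
        break
n_updates_old += n_epochs  # reports 10, even though only 4 epochs ran

# New counting: increment once per completed epoch, before checking the stop flag.
n_updates_new = 0
for epoch in range(n_epochs):
    n_updates_new += 1
    if epoch == stop_after_epoch:
        break

print(n_updates_old, n_updates_new)  # 10 4

With the old counting, the logged train/n_updates metric kept growing by n_epochs even when only a few epochs actually ran, so it drifted from the real number of optimization epochs; incrementing once per completed epoch keeps the metric consistent with early stopping.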
