
Commit b1cc159

Use higher resolution time_ns() and avoid division by zero (#979)
* Use higher resolution time and round up to eps
* Update changelog
* Add test case
* Fix formatting, time()->time_ns
* Bugfix: ns is integer not float
* Move test to better place
* Divide by 1e9 earlier
1 parent fda3d4d commit b1cc159

File tree

6 files changed, +27 -6 lines changed


docs/misc/changelog.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ SB3-Contrib
 Bug Fixes:
 ^^^^^^^^^^
 - Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
+- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
 
 Deprecations:
 ^^^^^^^^^^^^^

stable_baselines3/common/base_class.py

Lines changed: 1 addition & 1 deletion
@@ -422,7 +422,7 @@ def _setup_learn(
         :param tb_log_name: the name of the run for tensorboard log
         :return:
         """
-        self.start_time = time.time()
+        self.start_time = time.time_ns()
 
         if self.ep_info_buffer is None or reset_num_timesteps:
             # Initialize buffers if they don't exist, or reinitialize if resetting counters
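
time.time_ns() (PEP 564, Python 3.7+) returns an integer nanosecond count and avoids the rounding that comes with packing a Unix timestamp into a float, which is why the start timestamp is now stored in nanoseconds. How much this matters depends on the platform clock, which on some systems advances in multi-millisecond steps. A rough standalone probe (not part of the repository) of the smallest tick each clock reports:

import time


def min_positive_delta(clock, samples=100_000):
    """Smallest positive difference observed between consecutive clock reads.

    Returns None if the clock never advanced during the probe.
    """
    smallest = None
    previous = clock()
    for _ in range(samples):
        current = clock()
        delta = current - previous
        if delta > 0 and (smallest is None or delta < smallest):
            smallest = delta
        previous = current
    return smallest


print("time.time()    smallest tick:", min_positive_delta(time.time), "s")
print("time.time_ns() smallest tick:", min_positive_delta(time.time_ns), "ns")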

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import pathlib
3+
import sys
34
import time
45
import warnings
56
from copy import deepcopy
@@ -427,8 +428,8 @@ def _dump_logs(self) -> None:
427428
"""
428429
Write log.
429430
"""
430-
time_elapsed = time.time() - self.start_time
431-
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
431+
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
432+
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
432433
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
433434
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
434435
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))

stable_baselines3/common/on_policy_algorithm.py

Lines changed: 4 additions & 2 deletions
@@ -1,3 +1,4 @@
+import sys
 import time
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
@@ -254,13 +255,14 @@ def learn(
 
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
-                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
+                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
                 if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                     self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                     self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                 self.logger.record("time/fps", fps)
-                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
+                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                 self.logger.dump(step=self.num_timesteps)
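
The same elapsed-time expression now appears verbatim in both the off-policy and on-policy code paths. A shared helper along these lines (hypothetical, not part of this commit or of stable-baselines3) would keep the two call sites in sync:

import sys
import time


def elapsed_seconds(start_time_ns: int) -> float:
    """Seconds since a time.time_ns() stamp, clamped to stay strictly positive."""
    return max((time.time_ns() - start_time_ns) / 1e9, sys.float_info.epsilon)


def safe_fps(num_timesteps: int, num_timesteps_at_start: int, start_time_ns: int) -> int:
    """Frames per second since training started; never divides by zero."""
    return int((num_timesteps - num_timesteps_at_start) / elapsed_seconds(start_time_ns))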

tests/test_logger.py

Lines changed: 14 additions & 0 deletions
@@ -1,6 +1,7 @@
 import os
 import time
 from typing import Sequence
+from unittest import mock
 
 import gym
 import numpy as np
@@ -381,3 +382,16 @@ def test_fps_logger(tmp_path, algo):
     # third time, FPS should be the same
     model.learn(100, log_interval=1, reset_num_timesteps=False)
     assert max_fps / 10 <= logger.name_to_value["time/fps"] <= max_fps
+
+
+@pytest.mark.parametrize("algo", [A2C, DQN])
+def test_fps_no_div_zero(algo):
+    """Set time to constant and train algorithm to check no division by zero error.
+
+    Time can appear to be constant during short runs on platforms with low-precision
+    timers. We should avoid division by zero errors e.g. when computing FPS in
+    this situation."""
+    with mock.patch("time.time", lambda: 42.0):
+        with mock.patch("time.time_ns", lambda: 42.0):
+            model = algo("MlpPolicy", "CartPole-v1")
+            model.learn(total_timesteps=100)
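
unittest.mock.patch replaces the named attribute on the time module itself, so while the patches are active any code that calls time.time() or time.time_ns() sees the frozen value and the measured elapsed time is exactly zero, which is the degenerate case the new guard has to survive. A minimal standalone sketch of the mechanism (the value 42 is arbitrary):

import sys
import time
from unittest import mock

# Freeze both clocks, as the new test does.
with mock.patch("time.time", lambda: 42.0), mock.patch("time.time_ns", lambda: 42):
    start = time.time_ns()
    # ... a short training run would execute here ...
    elapsed = max((time.time_ns() - start) / 1e9, sys.float_info.epsilon)
    assert elapsed == sys.float_info.epsilon  # clamped, never zero

print("clock unfrozen again:", time.time_ns() > 42)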

tests/test_run.py

Lines changed: 4 additions & 1 deletion
@@ -10,7 +10,10 @@
 
 
 @pytest.mark.parametrize("model_class", [TD3, DDPG])
-@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))])
+@pytest.mark.parametrize(
+    "action_noise",
+    [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))],
+)
 def test_deterministic_pg(model_class, action_noise):
     """
     Test for DDPG and variants (TD3).
