Skip to content

Commit 5549b34

Browse files
authored
Fix stable_baselines3/common/vec_env/vec_check_nan.py type hints (#1226)
* super() init style
* "async_step" arg to "event"; "news" to "dones"; improve docstring
* Remove vec_check_nan from mypy exclude
* Update changelog
1 parent 9aff113 commit 5549b34

File tree

3 files changed

+23
-26
lines changed

3 files changed

+23
-26
lines changed

docs/misc/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ Others:
5757
- Fixed ``stable_baselines3/common/env_util.py`` type hints
5858
- Fixed ``stable_baselines3/common/preprocessing.py`` type hints
5959
- Fixed ``stable_baselines3/common/atari_wrappers.py`` type hints
60+
- Fixed ``stable_baselines3/common/vec_env/vec_check_nan.py`` type hints
6061
- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong)
6162
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
6263
- Set tensors construction directly on the device (~8% speed boost on GPU)

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ exclude = (?x)(
5252
| stable_baselines3/common/vec_env/stacked_observations.py$
5353
| stable_baselines3/common/vec_env/subproc_vec_env.py$
5454
| stable_baselines3/common/vec_env/util.py$
55-
| stable_baselines3/common/vec_env/vec_check_nan.py$
5655
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
5756
| stable_baselines3/common/vec_env/vec_frame_stack.py$
5857
| stable_baselines3/common/vec_env/vec_monitor.py$

stable_baselines3/common/vec_env/vec_check_nan.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,40 @@ class VecCheckNan(VecEnvWrapper):
1111
allowing you to know where the NaN or inf originated from.
1212
1313
:param venv: the vectorized environment to wrap
14-
:param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning
15-
:param warn_once: Whether or not to only warn once.
16-
:param check_inf: Whether or not to check for +inf or -inf as well
14+
:param raise_exception: Whether to raise a ValueError, instead of a UserWarning
15+
:param warn_once: Whether to only warn once.
16+
:param check_inf: Whether to check for +inf or -inf as well
1717
"""
1818

19-
def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True):
20-
VecEnvWrapper.__init__(self, venv)
19+
def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True) -> None:
20+
super().__init__(venv)
2121
self.raise_exception = raise_exception
2222
self.warn_once = warn_once
2323
self.check_inf = check_inf
24-
self._actions = None
25-
self._observations = None
24+
2625
self._user_warned = False
2726

28-
def step_async(self, actions: np.ndarray) -> None:
29-
self._check_val(async_step=True, actions=actions)
27+
self._actions: np.ndarray
28+
self._observations: VecEnvObs
3029

30+
def step_async(self, actions: np.ndarray) -> None:
31+
self._check_val(event="step_async", actions=actions)
3132
self._actions = actions
3233
self.venv.step_async(actions)
3334

3435
def step_wait(self) -> VecEnvStepReturn:
35-
observations, rewards, news, infos = self.venv.step_wait()
36-
37-
self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)
38-
36+
observations, rewards, dones, infos = self.venv.step_wait()
37+
self._check_val(event="step_wait", observations=observations, rewards=rewards, dones=dones)
3938
self._observations = observations
40-
return observations, rewards, news, infos
39+
return observations, rewards, dones, infos
4140

4241
def reset(self) -> VecEnvObs:
4342
observations = self.venv.reset()
44-
self._actions = None
45-
46-
self._check_val(async_step=False, observations=observations)
47-
43+
self._check_val(event="reset", observations=observations)
4844
self._observations = observations
4945
return observations
5046

51-
def _check_val(self, *, async_step: bool, **kwargs) -> None:
47+
def _check_val(self, event: str, **kwargs) -> None:
5248
# If we only warn (no exception), warn only once, and have already warned: stop checking
5349
if not self.raise_exception and self.warn_once and self._user_warned:
5450
return
@@ -72,13 +68,14 @@ def _check_val(self, *, async_step: bool, **kwargs) -> None:
7268

7369
msg += ".\r\nOriginated from the "
7470

75-
if not async_step:
76-
if self._actions is None:
77-
msg += "environment observation (at reset)"
78-
else:
79-
msg += f"environment, Last given value was: \r\n\taction={self._actions}"
80-
else:
71+
if event == "reset":
72+
msg += "environment observation (at reset)"
73+
elif event == "step_wait":
74+
msg += f"environment, Last given value was: \r\n\taction={self._actions}"
75+
elif event == "step_async":
8176
msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
77+
else:
78+
raise ValueError("Internal error.")
8279

8380
if self.raise_exception:
8481
raise ValueError(msg)

0 commit comments

Comments
 (0)