@@ -11,44 +11,40 @@ class VecCheckNan(VecEnvWrapper):
11
11
allowing you to know where the NaN or inf originated from.
12
12
13
13
:param venv: the vectorized environment to wrap
14
- :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning
15
- :param warn_once: Whether or not to only warn once.
16
- :param check_inf: Whether or not to check for +inf or -inf as well
14
+ :param raise_exception: Whether to raise a ValueError, instead of a UserWarning
15
+ :param warn_once: Whether to only warn once.
16
+ :param check_inf: Whether to check for +inf or -inf as well
17
17
"""
18
18
19
- def __init__ (self , venv : VecEnv , raise_exception : bool = False , warn_once : bool = True , check_inf : bool = True ):
20
- VecEnvWrapper .__init__ (self , venv )
19
+ def __init__ (self , venv : VecEnv , raise_exception : bool = False , warn_once : bool = True , check_inf : bool = True ) -> None :
20
+ super () .__init__ (venv )
21
21
self .raise_exception = raise_exception
22
22
self .warn_once = warn_once
23
23
self .check_inf = check_inf
24
- self ._actions = None
25
- self ._observations = None
24
+
26
25
self ._user_warned = False
27
26
28
- def step_async ( self , actions : np .ndarray ) -> None :
29
- self ._check_val ( async_step = True , actions = actions )
27
+ self . _actions : np .ndarray
28
+ self ._observations : VecEnvObs
30
29
30
+ def step_async (self , actions : np .ndarray ) -> None :
31
+ self ._check_val (event = "step_async" , actions = actions )
31
32
self ._actions = actions
32
33
self .venv .step_async (actions )
33
34
34
35
def step_wait (self ) -> VecEnvStepReturn :
35
- observations , rewards , news , infos = self .venv .step_wait ()
36
-
37
- self ._check_val (async_step = False , observations = observations , rewards = rewards , news = news )
38
-
36
+ observations , rewards , dones , infos = self .venv .step_wait ()
37
+ self ._check_val (event = "step_wait" , observations = observations , rewards = rewards , dones = dones )
39
38
self ._observations = observations
40
- return observations , rewards , news , infos
39
+ return observations , rewards , dones , infos
41
40
42
41
def reset (self ) -> VecEnvObs :
43
42
observations = self .venv .reset ()
44
- self ._actions = None
45
-
46
- self ._check_val (async_step = False , observations = observations )
47
-
43
+ self ._check_val (event = "reset" , observations = observations )
48
44
self ._observations = observations
49
45
return observations
50
46
51
- def _check_val (self , * , async_step : bool , ** kwargs ) -> None :
47
+ def _check_val (self , event : str , ** kwargs ) -> None :
52
48
# if warn and warn once and have warned once: then stop checking
53
49
if not self .raise_exception and self .warn_once and self ._user_warned :
54
50
return
@@ -72,13 +68,14 @@ def _check_val(self, *, async_step: bool, **kwargs) -> None:
72
68
73
69
msg += ".\r \n Originated from the "
74
70
75
- if not async_step :
76
- if self ._actions is None :
77
- msg += "environment observation (at reset)"
78
- else :
79
- msg += f"environment, Last given value was: \r \n \t action={ self ._actions } "
80
- else :
71
+ if event == "reset" :
72
+ msg += "environment observation (at reset)"
73
+ elif event == "step_wait" :
74
+ msg += f"environment, Last given value was: \r \n \t action={ self ._actions } "
75
+ elif event == "step_async" :
81
76
msg += f"RL model, Last given value was: \r \n \t observations={ self ._observations } "
77
+ else :
78
+ raise ValueError ("Internal error." )
82
79
83
80
if self .raise_exception :
84
81
raise ValueError (msg )
0 commit comments