Skip to content

Commit 92f7a6f

Browse files
qgallouedec and araffin
authored
Fix test_vec_normalize.py, test_tensorboard.py and common/monitor.py type hint (#1194)
* Remove from mypy exclude
* type hint for metadata
* Union[float, int] -> float
* Remove useless __init__
* Type hint for model and logger in BaseCallback
* Type hint for metric_dict
* Update changelog
* fix test_tensorboard
* ignore gamma type checking
* Fix monitor type hint
* Update logger type hints
* Fix type annotation and bump version
* Fix circular import

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 9bb1538 commit 92f7a6f

File tree

12 files changed

+47
-42
lines changed

12 files changed

+47
-42
lines changed

docs/guide/tensorboard.rst

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,9 @@ Here is an example of how to save hyperparameters in TensorBoard:
268268
269269
270270
class HParamCallback(BaseCallback):
271-
def __init__(self):
272-
"""
273-
Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
274-
"""
275-
super().__init__()
271+
"""
272+
Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
273+
"""
276274
277275
def _on_training_start(self) -> None:
278276
hparam_dict = {
@@ -284,7 +282,7 @@ Here is an example of how to save hyperparameters in TensorBoard:
284282
# TensorBoard will find & display metrics from the `SCALARS` tab
285283
metric_dict = {
286284
"rollout/ep_len_mean": 0,
287-
"train/value_loss": 0,
285+
"train/value_loss": 0.0,
288286
}
289287
self.logger.record(
290288
"hparams",

docs/misc/changelog.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Changelog
44
==========
55

66

7-
Release 1.8.0a0 (WIP)
7+
Release 1.8.0a1 (WIP)
88
--------------------------
99

1010

@@ -28,6 +28,9 @@ Deprecations:
2828

2929
Others:
3030
^^^^^^^
31+
- Fixed ``tests/test_tensorboard.py`` type hint
32+
- Fixed ``tests/test_vec_normalize.py`` type hint
33+
- Fixed ``stable_baselines3/common/monitor.py`` type hint
3134

3235
Documentation:
3336
^^^^^^^^^^^^^^

setup.cfg

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ exclude = (?x)(
3939
| stable_baselines3/common/envs/identity_env.py$
4040
| stable_baselines3/common/envs/multi_input_envs.py$
4141
| stable_baselines3/common/logger.py$
42-
| stable_baselines3/common/monitor.py$
4342
| stable_baselines3/common/off_policy_algorithm.py$
4443
| stable_baselines3/common/on_policy_algorithm.py$
4544
| stable_baselines3/common/policies.py$
@@ -67,9 +66,7 @@ exclude = (?x)(
6766
| stable_baselines3/td3/policies.py$
6867
| stable_baselines3/td3/td3.py$
6968
| tests/test_logger.py$
70-
| tests/test_tensorboard.py$
7169
| tests/test_train_eval_mode.py$
72-
| tests/test_vec_normalize.py$
7370
)
7471

7572
[flake8]

stable_baselines3/common/callbacks.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import gym
77
import numpy as np
88

9+
from stable_baselines3.common.logger import Logger
10+
911
try:
1012
from tqdm import TqdmExperimentalWarning
1113

@@ -29,10 +31,13 @@ class BaseCallback(ABC):
2931
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
3032
"""
3133

34+
# The RL model
35+
# Type hint as string to avoid circular import
36+
model: "base_class.BaseAlgorithm"
37+
logger: Logger
38+
3239
def __init__(self, verbose: int = 0):
3340
super().__init__()
34-
# The RL model
35-
self.model = None # type: Optional[base_class.BaseAlgorithm]
3641
# An alias for self.model.get_env(), the environment used for training
3742
self.training_env = None # type: Union[gym.Env, VecEnv, None]
3843
# Number of times the callback was called
@@ -42,7 +47,6 @@ def __init__(self, verbose: int = 0):
4247
self.verbose = verbose
4348
self.locals: Dict[str, Any] = {}
4449
self.globals: Dict[str, Any] = {}
45-
self.logger = None
4650
# Sometimes, for event callback, it is useful
4751
# to have access to the parent object
4852
self.parent = None # type: Optional[BaseCallback]

stable_baselines3/common/envs/multi_input_envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def init_possible_transitions(self) -> None:
121121
self.right_possible = [0, 1, 2, 12, 13, 14]
122122
self.up_possible = [4, 8, 12, 7, 11, 15]
123123

124-
def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn:
124+
def step(self, action: Union[float, np.ndarray]) -> GymStepReturn:
125125
"""
126126
Run one timestep of the environment's dynamics. When end of
127127
episode is reached, you are responsible for calling `reset()`

stable_baselines3/common/logger.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import tempfile
66
import warnings
77
from collections import defaultdict
8-
from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union
8+
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union
99

1010
import numpy as np
1111
import pandas
@@ -16,7 +16,7 @@
1616
from torch.utils.tensorboard import SummaryWriter
1717
from torch.utils.tensorboard.summary import hparams
1818
except ImportError:
19-
SummaryWriter = None
19+
SummaryWriter = None # type: ignore[misc, assignment]
2020

2121
try:
2222
from tqdm import tqdm
@@ -38,7 +38,7 @@ class Video:
3838
:param fps: frames per second
3939
"""
4040

41-
def __init__(self, frames: th.Tensor, fps: Union[float, int]):
41+
def __init__(self, frames: th.Tensor, fps: float):
4242
self.frames = frames
4343
self.fps = fps
4444

@@ -80,7 +80,7 @@ class HParam:
8080
A non-empty metrics dict is required to display hyperparameters in the corresponding Tensorboard section.
8181
"""
8282

83-
def __init__(self, hparam_dict: Dict[str, Union[bool, str, float, int, None]], metric_dict: Dict[str, Union[float, int]]):
83+
def __init__(self, hparam_dict: Mapping[str, Union[bool, str, float, None]], metric_dict: Mapping[str, float]):
8484
self.hparam_dict = hparam_dict
8585
if not metric_dict:
8686
raise Exception("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.")
@@ -329,7 +329,7 @@ class CSVOutputFormat(KVWriter):
329329

330330
def __init__(self, filename: str):
331331
self.file = open(filename, "w+t")
332-
self.keys = []
332+
self.keys: List[str] = []
333333
self.separator = ","
334334
self.quotechar = '"'
335335

stable_baselines3/common/monitor.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import time
77
from glob import glob
8-
from typing import Dict, List, Optional, Tuple, Union
8+
from typing import Any, Dict, List, Optional, Tuple, Union
99

1010
import gym
1111
import numpy as np
@@ -41,25 +41,26 @@ def __init__(
4141
):
4242
super().__init__(env=env)
4343
self.t_start = time.time()
44+
self.results_writer = None
4445
if filename is not None:
4546
self.results_writer = ResultsWriter(
4647
filename,
4748
header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
4849
extra_keys=reset_keywords + info_keywords,
4950
override_existing=override_existing,
5051
)
51-
else:
52-
self.results_writer = None
52+
5353
self.reset_keywords = reset_keywords
5454
self.info_keywords = info_keywords
5555
self.allow_early_resets = allow_early_resets
56-
self.rewards = None
56+
self.rewards: List[float] = []
5757
self.needs_reset = True
58-
self.episode_returns = []
59-
self.episode_lengths = []
60-
self.episode_times = []
58+
self.episode_returns: List[float] = []
59+
self.episode_lengths: List[int] = []
60+
self.episode_times: List[float] = []
6161
self.total_steps = 0
62-
self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
62+
# extra info about the current episode, that was passed in during reset()
63+
self.current_reset_info: Dict[str, Any] = {}
6364

6465
def reset(self, **kwargs) -> GymObs:
6566
"""
@@ -200,7 +201,7 @@ def __init__(
200201

201202
self.file_handler.flush()
202203

203-
def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
204+
def write_row(self, epinfo: Dict[str, float]) -> None:
204205
"""
205206
Write one row of episode information (``epinfo``)
206207

stable_baselines3/common/running_mean_std.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Union
1+
from typing import Tuple
22

33
import numpy as np
44

@@ -40,7 +40,7 @@ def update(self, arr: np.ndarray) -> None:
4040
batch_count = arr.shape[0]
4141
self.update_from_moments(batch_mean, batch_var, batch_count)
4242

43-
def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: Union[int, float]) -> None:
43+
def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: float) -> None:
4444
delta = batch_mean - self.mean
4545
tot_count = self.count + batch_count
4646

stable_baselines3/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
7676
param_group["lr"] = learning_rate
7777

7878

79-
def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
79+
def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
8080
"""
8181
Transform (if needed) learning rate and clip range (for PPO)
8282
to callable.

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.8.0a0
1+
1.8.0a1

0 commit comments

Comments
 (0)