Skip to content

Commit dba0baa

Browse files
authored
Fix mypy error (#2067)
* Fix mypy error * Ignore new errors
1 parent 57e8b97 commit dba0baa

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

stable_baselines3/common/save_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def json_to_data(json_string: str, custom_objects: Optional[dict[str, Any]] = No
181181
@functools.singledispatch
182182
def open_path(
183183
path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None
184-
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]:
184+
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO, io.BufferedRandom]:
185185
"""
186186
Opens a path for reading or writing with a preferred suffix and raises debug information.
187187
If the provided path is a derivative of io.BufferedIOBase it ensures that the file

stable_baselines3/common/vec_env/vec_normalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def normalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[
241241
assert self.norm_obs_keys is not None
242242
# Only normalize the specified keys
243243
for key in self.norm_obs_keys:
244-
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
244+
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) # type: ignore[call-overload]
245245
else:
246246
assert isinstance(self.obs_rms, RunningMeanStd)
247247
obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
@@ -265,7 +265,7 @@ def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Unio
265265
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
266266
assert self.norm_obs_keys is not None
267267
for key in self.norm_obs_keys:
268-
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
268+
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) # type: ignore[call-overload]
269269
else:
270270
assert isinstance(self.obs_rms, RunningMeanStd)
271271
obs_ = self._unnormalize_obs(obs, self.obs_rms)

tests/test_vec_envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder
1818

1919
try:
20-
import moviepy
20+
import moviepy # noqa: F401
2121

2222
have_moviepy = True
2323
except ImportError:

0 commit comments

Comments
 (0)