Skip to content

Commit d64bcb4

Browse files
authored
Fix exception cause in base_class.py (#940)
1 parent 7ce7b6a commit d64bcb4

File tree

8 files changed

+20
-17
lines changed

8 files changed

+20
-17
lines changed

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Bug Fixes:
3232
- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled
3333
- Added a check for unbounded actions
3434
- Fixed issues due to newer version of protobuf (tensorboard) and sphinx
35+
- Fix exception causes all over the codebase (@cool-RR)
3536

3637
Deprecations:
3738
^^^^^^^^^^^^^
@@ -978,4 +979,4 @@ And all the contributors:
978979
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
979980
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
980981
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
981-
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG
982+
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR

stable_baselines3/common/base_class.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,11 +628,11 @@ def set_parameters(
628628
attr = None
629629
try:
630630
attr = recursive_getattr(self, name)
631-
except Exception:
631+
except Exception as e:
632632
# What errors recursive_getattr could throw? KeyError, but
633633
# possible something else too (e.g. if key is an int?).
634634
# Catch anything for now.
635-
raise ValueError(f"Key {name} is an invalid object name.")
635+
raise ValueError(f"Key {name} is an invalid object name.") from e
636636

637637
if isinstance(attr, th.optim.Optimizer):
638638
# Optimizers do not support "strict" keyword...

stable_baselines3/common/callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,12 +380,12 @@ def _on_step(self) -> bool:
380380
if self.model.get_vec_normalize_env() is not None:
381381
try:
382382
sync_envs_normalization(self.training_env, self.eval_env)
383-
except AttributeError:
383+
except AttributeError as e:
384384
raise AssertionError(
385385
"Training and eval env are not wrapped the same way, "
386386
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
387387
"and warning above."
388-
)
388+
) from e
389389

390390
# Reset success rate buffer
391391
self._is_success_buffer = []

stable_baselines3/common/env_checker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
147147
try:
148148
_check_obs(obs[key], observation_space.spaces[key], "reset")
149149
except AssertionError as e:
150-
raise AssertionError(f"Error while checking key={key}: " + str(e))
150+
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
151151
else:
152152
_check_obs(obs, observation_space, "reset")
153153

@@ -166,7 +166,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
166166
try:
167167
_check_obs(obs[key], observation_space.spaces[key], "step")
168168
except AssertionError as e:
169-
raise AssertionError(f"Error while checking key={key}: " + str(e))
169+
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
170170

171171
else:
172172
_check_obs(obs, observation_space, "step")

stable_baselines3/common/noise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def __init__(self, base_noise: ActionNoise, n_envs: int):
105105
try:
106106
self.n_envs = int(n_envs)
107107
assert self.n_envs > 0
108-
except (TypeError, AssertionError):
109-
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")
108+
except (TypeError, AssertionError) as e:
109+
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e
110110

111111
self.base_noise = base_noise
112112
self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,10 @@ def _convert_train_freq(self) -> None:
157157

158158
try:
159159
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
160-
except ValueError:
161-
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
160+
except ValueError as e:
161+
raise ValueError(
162+
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
163+
) from e
162164

163165
if not isinstance(train_freq[0], int):
164166
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

stable_baselines3/common/save_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
206206
mode = mode.lower()
207207
try:
208208
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
209-
except KeyError:
210-
raise ValueError("Expected mode to be either 'w' or 'r'.")
209+
except KeyError as e:
210+
raise ValueError("Expected mode to be either 'w' or 'r'.") from e
211211
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
212212
e1 = "writable" if "w" == mode else "readable"
213213
raise ValueError(f"Expected a {e1} file.")
@@ -441,7 +441,7 @@ def load_from_zip_file(
441441
# State dicts. Store into params dictionary
442442
# with same name as in .zip file (without .pth)
443443
params[os.path.splitext(file_path)[0]] = th_object
444-
except zipfile.BadZipFile:
444+
except zipfile.BadZipFile as e:
445445
# load_path wasn't a zip file
446-
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
446+
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
447447
return data, params, pytorch_variables

stable_baselines3/her/her_replay_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in
2828
if current_max_episode_length is None:
2929
raise AttributeError
3030
# if not available check if a valid value was passed as an argument
31-
except AttributeError:
31+
except AttributeError as e:
3232
raise ValueError(
3333
"The max episode length could not be inferred.\n"
3434
"You must specify a `max_episode_steps` when registering the environment,\n"
3535
"use a `gym.wrappers.TimeLimit` wrapper "
3636
"or pass `max_episode_length` to the model constructor"
37-
)
37+
) from e
3838
return current_max_episode_length
3939

4040

0 commit comments

Comments
 (0)