From 9f2e1e4d54ce8ae592c573e294c51fbc33afbbe1 Mon Sep 17 00:00:00 2001 From: flowers Date: Mon, 26 Jun 2023 18:46:05 +0200 Subject: [PATCH 1/2] fix Framestack obs env_checker --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/env_checker.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index ebf27a5a3e..cfff4128eb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -50,6 +50,7 @@ Bug Fixes: for ``Inf`` and ``NaN`` (@lutogniew) - Fixed HER ``truncate_last_trajectory()`` (@lbergmann1) - Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz) +- Fixed env_checker for FrameStack observation (@corentinlger) Deprecations: ^^^^^^^^^^^^^ @@ -1361,4 +1362,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 +@lutogniew @lbergmann1 @corentinlger diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 516d7ba61d..db45277e7f 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -4,6 +4,7 @@ import gymnasium as gym import numpy as np from gymnasium import spaces +from gymnasium.wrappers.frame_stack import LazyFrames from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -171,7 +172,9 @@ def _check_goal_env_compute_reward( assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}" -def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None: +def _check_obs( + obs: Union[tuple, dict, np.ndarray, int, LazyFrames], observation_space: spaces.Space, method_name: str +) -> None: """ Check that the observation returned by the environment correspond to the declared one. @@ -187,7 +190,10 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac # `sample()` will return a np.int64 instead of an int assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int" elif _is_numpy_array_space(observation_space): - assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array" + # Check if obs is a ndarray or a FrameStacking of ndarrays + assert isinstance( + obs, (np.ndarray, LazyFrames) + ), f"The observation returned by `{method_name}()` method must be a numpy array" # Additional checks for numpy arrays, so the error message is clearer (see GH#1399) if isinstance(obs, np.ndarray): From aa56133327203a2e2df43d69d29173ca3ff939ba Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 5 Aug 2023 18:14:23 +0200 Subject: [PATCH 2/2] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/misc/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 336e1b5a1c..fd208d6afb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -93,7 +93,7 @@ Bug Fixes: for ``Inf`` and ``NaN`` (@lutogniew) - Fixed HER ``truncate_last_trajectory()`` (@lbergmann1) - Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz) -- Fixed env_checker for FrameStack observation (@corentinlger) +- Fixed ``env_checker`` for ``FrameStack`` observation (@corentinlger) Deprecations: ^^^^^^^^^^^^^