diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fd209dea8b..ebb08794f5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -99,6 +99,7 @@ Bug Fixes: for ``Inf`` and ``NaN`` (@lutogniew) - Fixed HER ``truncate_last_trajectory()`` (@lbergmann1) - Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz) +- Fixed ``env_checker`` for ``FrameStack`` observation (@corentinlger) Deprecations: ^^^^^^^^^^^^^ @@ -1410,4 +1411,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he +@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @corentinlger \ No newline at end of file diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 8b8da7f440..8af46bffce 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -4,6 +4,7 @@ import gymnasium as gym import numpy as np from gymnasium import spaces +from gymnasium.wrappers.frame_stack import LazyFrames from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -171,7 +172,9 @@ def _check_goal_env_compute_reward( assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}" -def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None: +def _check_obs( + obs: Union[tuple, dict, np.ndarray, int, LazyFrames], observation_space: spaces.Space, method_name: str +) -> None: """ Check that the observation returned by the environment correspond to the declared one. @@ -187,7 +190,10 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac # `sample()` will return a np.int64 instead of an int assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int" elif _is_numpy_array_space(observation_space): - assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array" + # Check if obs is a ndarray or a FrameStacking of ndarrays + assert isinstance( + obs, (np.ndarray, LazyFrames) + ), f"The observation returned by `{method_name}()` method must be a numpy array" # Additional checks for numpy arrays, so the error message is clearer (see GH#1399) if isinstance(obs, np.ndarray):