Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Bug Fixes:
for ``Inf`` and ``NaN`` (@lutogniew)
- Fixed HER ``truncate_last_trajectory()`` (@lbergmann1)
- Fixed HER desired and achieved goal order in reward computation (@JonathanKuelz)
- Fixed env_checker for FrameStack observation (@corentinlger)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -1361,4 +1362,4 @@ And all the contributors:
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1
@lutogniew @lbergmann1 @corentinlger
10 changes: 8 additions & 2 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.wrappers.frame_stack import LazyFrames

from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
Expand Down Expand Up @@ -171,7 +172,9 @@ def _check_goal_env_compute_reward(
assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}"


def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None:
def _check_obs(
obs: Union[tuple, dict, np.ndarray, int, LazyFrames], observation_space: spaces.Space, method_name: str
) -> None:
"""
Check that the observation returned by the environment
correspond to the declared one.
Expand All @@ -187,7 +190,10 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac
# `sample()` will return a np.int64 instead of an int
assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int"
elif _is_numpy_array_space(observation_space):
assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"
# Check if obs is a ndarray or a FrameStacking of ndarrays
assert isinstance(
obs, (np.ndarray, LazyFrames)
), f"The observation returned by `{method_name}()` method must be a numpy array"

# Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
if isinstance(obs, np.ndarray):
Expand Down