diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index da009e2c7..c682fbe41 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -183,6 +183,59 @@ This callback can then be used to safely modify environment attributes during tr it calls the environment setter method. +Checking VecEnv Implementation +------------------------------ + +When implementing custom vectorized environments, it's easy to make mistakes that can lead to hard-to-debug issues. +To help with this, Stable-Baselines3 provides a ``check_vecenv`` function that validates your VecEnv implementation +and checks for common issues. + +The ``check_vecenv`` function verifies: + +* The VecEnv properly inherits from ``stable_baselines3.common.vec_env.VecEnv`` +* Required attributes (``num_envs``, ``observation_space``, ``action_space``) are present and valid +* The ``reset()`` method returns observations with the correct vectorized shape (batch dimension first) +* The ``step()`` method returns properly shaped observations, rewards, dones, and infos +* All return values have the expected types and dimensions +* Compatibility with Stable-Baselines3 algorithms + +**Usage:** + +.. code-block:: python + + from stable_baselines3.common.vec_env import DummyVecEnv + from stable_baselines3.common.vec_env_checker import check_vecenv + import gymnasium as gym + + def make_env(): + return gym.make('CartPole-v1') + + # Create your VecEnv + vec_env = DummyVecEnv([make_env for _ in range(4)]) + + # Check the VecEnv implementation + check_vecenv(vec_env, warn=True) + + vec_env.close() + +**When to use:** + +* When implementing a custom VecEnv class +* When debugging issues with vectorized environments +* When contributing new VecEnv implementations to ensure they follow the API +* As a sanity check before training to catch potential issues early + +**Note:** Similar to ``check_env`` for single environments, ``check_vecenv`` is particularly useful during development +and debugging. 
It helps catch common vectorization mistakes like incorrect batch dimensions, wrong return types, or +missing required methods. + + +VecEnv Checker +~~~~~~~~~~~~~~ + +.. autofunction:: stable_baselines3.common.vec_env_checker.check_vecenv + + Vectorized Environments Wrappers -------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 73c2ab202..a173ff40d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,6 +12,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - ``RolloutBuffer`` and ``DictRolloutBuffer`` now uses the actual observation / action space ``dtype`` (instead of float32), this should save memory (@Trenza1ore) +- Added ``check_vecenv()`` function to check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3 (@copilot) Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index ac49a0469..bcbecc62f 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -13,6 +13,9 @@ from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder +# Avoid circular import by importing the vec_env_checker here +from stable_baselines3.common.vec_env_checker import check_vecenv + VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper) @@ -98,6 +101,7 @@ def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None: "VecNormalize", "VecTransposeImage", "VecVideoRecorder", + "check_vecenv", "is_vecenv_wrapped", "sync_envs_normalization", "unwrap_vec_normalize", diff --git a/stable_baselines3/common/vec_env_checker.py b/stable_baselines3/common/vec_env_checker.py new file mode 100644 index 000000000..c614a5815 --- /dev/null +++ b/stable_baselines3/common/vec_env_checker.py @@ -0,0 +1,182 @@ +import warnings +from typing import Any + +import numpy as np +from 
def _check_vecenv_spaces(vec_env: VecEnv) -> None:
    """
    Check that the VecEnv has valid observation/action spaces and a valid ``num_envs``.

    :param vec_env: The vectorized environment to check
    :raises AssertionError: If one of ``observation_space``, ``action_space`` or ``num_envs``
        is missing, if a space is not a ``gymnasium.spaces.Space`` instance,
        or if ``num_envs`` is not a strictly positive ``int``.
    """
    assert hasattr(vec_env, "observation_space"), "VecEnv must have an observation_space attribute"
    assert hasattr(vec_env, "action_space"), "VecEnv must have an action_space attribute"
    assert hasattr(vec_env, "num_envs"), "VecEnv must have a num_envs attribute"

    assert isinstance(
        vec_env.observation_space, spaces.Space
    ), f"The observation space must inherit from gymnasium.spaces, got {type(vec_env.observation_space)}"
    assert isinstance(
        vec_env.action_space, spaces.Space
    ), f"The action space must inherit from gymnasium.spaces, got {type(vec_env.action_space)}"
    assert (
        isinstance(vec_env.num_envs, int) and vec_env.num_envs > 0
    ), f"num_envs must be a positive integer, got {vec_env.num_envs} (type: {type(vec_env.num_envs)})"


def _check_vecenv_obs(obs: Any, observation_space: spaces.Space, num_envs: int, method_name: str) -> None:
    """
    Check that a batch of observations has the type and shape expected for ``observation_space``.

    Shared by the ``reset()`` and ``step()`` checks so that both apply exactly the same rules.
    Only ``Box``, ``Dict`` (its ``Box`` entries) and ``Discrete`` spaces are inspected;
    observations from any other space type are accepted without checks.

    :param obs: The observations returned by the VecEnv
    :param observation_space: The (single environment) observation space of the VecEnv
    :param num_envs: Number of environments, i.e. the expected size of the leading batch dimension
    :param method_name: Name of the VecEnv method that produced ``obs`` (``"reset"`` or ``"step"``),
        only used in the error messages
    :raises AssertionError: If the type or the shape of ``obs`` is not the expected one.
    """
    if isinstance(observation_space, spaces.Box):
        assert isinstance(
            obs, np.ndarray
        ), f"For Box observation space, {method_name}() must return np.ndarray, got {type(obs)}"
        expected_shape = (num_envs, *observation_space.shape)
        assert obs.shape == expected_shape, (
            f"Expected observation shape {expected_shape}, got {obs.shape}. "
            f"VecEnv observations should have batch dimension first."
        )
    elif isinstance(observation_space, spaces.Dict):
        assert isinstance(obs, dict), f"For Dict observation space, {method_name}() must return dict, got {type(obs)}"
        for key, space in observation_space.spaces.items():
            assert key in obs, f"Missing key '{key}' in observation dict"
            if isinstance(space, spaces.Box):
                # Check the type first: reading ``.shape`` on a non-array would raise
                # an AttributeError instead of a descriptive assertion error.
                assert isinstance(
                    obs[key], np.ndarray
                ), f"observation['{key}'] returned by {method_name}() must be np.ndarray, got {type(obs[key])}"
                expected_shape = (num_envs, *space.shape)
                assert (
                    obs[key].shape == expected_shape
                ), f"Expected observation['{key}'] shape {expected_shape}, got {obs[key].shape}"
    elif isinstance(observation_space, spaces.Discrete):
        assert isinstance(
            obs, np.ndarray
        ), f"For Discrete observation space, {method_name}() must return np.ndarray, got {type(obs)}"
        expected_shape = (num_envs,)
        assert obs.shape == expected_shape, f"Expected observation shape {expected_shape}, got {obs.shape}"


def _check_vecenv_reset(vec_env: VecEnv) -> Any:
    """
    Check that the VecEnv ``reset()`` method returns properly shaped observations.

    :param vec_env: The vectorized environment to check
    :return: The observations returned by ``reset()``
    :raises AssertionError: If the observations do not match the observation space.
    """
    obs = vec_env.reset()
    _check_vecenv_obs(obs, vec_env.observation_space, vec_env.num_envs, "reset")
    return obs


def _check_vecenv_step(vec_env: VecEnv, obs: Any) -> None:
    """
    Check that the VecEnv ``step()`` method returns properly typed and shaped values.

    One action is sampled from the action space for each environment and a single step is taken.

    :param vec_env: The vectorized environment to check
    :param obs: Observations from the previous ``reset()``. Currently unused,
        kept so the signature stays compatible with existing callers.
    :raises AssertionError: If observations, rewards, dones or infos do not have
        the expected type, shape or dtype.
    """
    num_envs = vec_env.num_envs
    # Generate one valid action per environment
    actions = np.array([vec_env.action_space.sample() for _ in range(num_envs)])

    new_obs, rewards, dones, infos = vec_env.step(actions)

    # Check rewards
    assert isinstance(rewards, np.ndarray), f"step() must return rewards as np.ndarray, got {type(rewards)}"
    assert rewards.shape == (num_envs,), f"Expected rewards shape ({num_envs},), got {rewards.shape}"

    # Check dones
    assert isinstance(dones, np.ndarray), f"step() must return dones as np.ndarray, got {type(dones)}"
    assert dones.shape == (num_envs,), f"Expected dones shape ({num_envs},), got {dones.shape}"
    assert dones.dtype == bool, f"dones must have dtype bool, got {dones.dtype}"

    # Check infos
    assert isinstance(infos, (list, tuple)), f"step() must return infos as list or tuple, got {type(infos)}"
    assert len(infos) == num_envs, f"Expected infos length {num_envs}, got {len(infos)}"
    for i, info in enumerate(infos):
        assert isinstance(info, dict), f"infos[{i}] must be dict, got {type(info)}"

    # Same observation rules as for reset() (this now also covers Discrete spaces)
    _check_vecenv_obs(new_obs, vec_env.observation_space, num_envs, "step")


class _DummyVecEnvForSpaceCheck:
    """Minimal stand-in object, only carrying the two spaces, to pass to ``_check_unsupported_spaces``."""

    def __init__(self, observation_space: spaces.Space, action_space: spaces.Space):
        self.observation_space = observation_space
        self.action_space = action_space


def _check_vecenv_unsupported_spaces(observation_space: spaces.Space, action_space: spaces.Space) -> bool:
    """
    Emit warnings when the observation space or action space used is not supported by Stable-Baselines
    for VecEnv. Reuses the existing ``_check_unsupported_spaces`` function from the env checker.

    :param observation_space: The observation space of the VecEnv
    :param action_space: The action space of the VecEnv
    :return: True if return value tests should be skipped.
    """
    # The env checker helper expects an env as first argument: give it a dummy object holding the spaces
    dummy_env = _DummyVecEnvForSpaceCheck(observation_space, action_space)
    return _check_unsupported_spaces(dummy_env, observation_space, action_space)  # type: ignore[arg-type]


def check_vecenv(vec_env: VecEnv, warn: bool = True) -> None:
    """
    Check that a VecEnv follows the VecEnv API and is compatible with Stable-Baselines3.

    This checker verifies that:
    - The VecEnv has proper observation_space, action_space, and num_envs attributes
    - The reset() method returns observations with correct vectorized shape
    - The step() method returns observations, rewards, dones, and infos with correct shapes
    - All return values have the expected types and dimensions

    Note that the environment is reset and then stepped once with randomly sampled actions.

    :param vec_env: The vectorized environment to check
    :param warn: Whether to output additional warnings mainly related to
        the interaction with Stable Baselines. When ``True`` and unsupported spaces
        are detected, the ``reset()``/``step()`` checks are skipped.
    :raises AssertionError: If ``vec_env`` is not a ``VecEnv`` or if one of the checks fails.
    """
    assert isinstance(vec_env, VecEnv), "Your environment must inherit from stable_baselines3.common.vec_env.VecEnv"

    # ============= Check basic VecEnv attributes ================
    _check_vecenv_spaces(vec_env)

    # Define aliases for convenience
    observation_space = vec_env.observation_space
    action_space = vec_env.action_space

    # Warn the user if needed - reuse existing space checking logic
    if warn:
        should_skip = _check_vecenv_unsupported_spaces(observation_space, action_space)
        if should_skip:
            warnings.warn("VecEnv contains unsupported spaces, skipping some checks")
            return

        # A non-Dict space is treated as a Dict with a single entry whose key is ""
        obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space}
        for key, space in obs_spaces.items():
            if isinstance(space, spaces.Box):
                _check_box_obs(space, key)

        # Checks for continuous action spaces
        if isinstance(action_space, spaces.Box):
            if (
                np.any(np.abs(action_space.low) != np.abs(action_space.high))
                or np.any(action_space.low != -1)
                or np.any(action_space.high != 1)
            ):
                warnings.warn(
                    "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
                    "cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
                )

            assert np.all(
                np.isfinite(np.array([action_space.low, action_space.high]))
            ), "Continuous action space must have a finite lower and upper bound"

            if action_space.dtype != np.dtype(np.float32):
                warnings.warn(
                    f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."
                )

    # ============ Check the VecEnv methods ===============
    obs = _check_vecenv_reset(vec_env)
    _check_vecenv_step(vec_env, obs)
class BrokenVecEnv:
    """Object exposing VecEnv-like attributes without inheriting from ``VecEnv``."""

    def __init__(self):
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))
        self.action_space = spaces.Discrete(2)
        self.num_envs = 2


class MissingAttributeVecEnv(VecEnv):
    """``VecEnv`` subclass that never sets ``num_envs`` nor the spaces."""

    def __init__(self):
        # On purpose: the parent constructor is not called, so no attribute is defined
        pass

    def reset(self):
        return None

    def step_async(self, actions):
        return None

    def step_wait(self):
        return None

    def close(self):
        return None

    def get_attr(self, attr_name, indices=None):
        return None

    def set_attr(self, attr_name, value, indices=None):
        return None

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        return None

    def env_is_wrapped(self, wrapper_class, indices=None):
        n_envs = getattr(self, "num_envs", 1)
        return [False] * n_envs


class WrongShapeVecEnv(VecEnv):
    """``VecEnv`` declaring 2 envs with 3-dim observations but returning unbatched values."""

    def __init__(self):
        obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))
        act_space = spaces.Discrete(2)
        super().__init__(num_envs=2, observation_space=obs_space, action_space=act_space)

    def reset(self):
        # Missing batch dimension: shape is (3,) where (2, 3) is expected
        return np.zeros(3)

    def step_async(self, actions):
        return None

    def step_wait(self):
        # Every returned value has a wrong shape/length:
        # obs should be (2, 3), rewards and dones (2,), infos should hold 2 dicts
        return np.zeros(3), np.zeros(3), np.zeros(3), [{}]

    def close(self):
        return None

    def get_attr(self, attr_name, indices=None):
        return [None] * self.num_envs

    def set_attr(self, attr_name, value, indices=None):
        return None

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        return [None] * self.num_envs

    def env_is_wrapped(self, wrapper_class, indices=None):
        return [False] * self.num_envs


def test_check_vecenv_basic():
    """A regular ``DummyVecEnv`` must pass the checker without raising."""
    vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")] * 2)

    try:
        check_vecenv(vec_env, warn=True)
    finally:
        vec_env.close()


def test_check_vecenv_not_vecenv():
    """Objects that do not inherit from ``VecEnv`` must be rejected."""
    with pytest.raises(AssertionError, match=r"must inherit from.*VecEnv"):
        check_vecenv(BrokenVecEnv())


def test_check_vecenv_missing_attributes():
    """A ``VecEnv`` without the required attributes must be rejected."""
    with pytest.raises(AssertionError, match=r"must have.*attribute"):
        check_vecenv(MissingAttributeVecEnv())


def test_check_vecenv_wrong_shapes():
    """Observations without the batch dimension must be detected."""
    wrong_env = WrongShapeVecEnv()

    try:
        with pytest.raises(AssertionError, match="Expected observation shape"):
            check_vecenv(wrong_env)
    finally:
        wrong_env.close()


def test_check_vecenv_dict_space():
    """The checker must accept a valid VecEnv with a ``Dict`` observation space."""

    class DictEnv(gym.Env):
        def __init__(self):
            self.action_space = spaces.Discrete(2)
            self.observation_space = spaces.Dict(
                {
                    "observation": spaces.Box(low=-1.0, high=1.0, shape=(4,)),
                    "achieved_goal": spaces.Box(low=-1.0, high=1.0, shape=(2,)),
                }
            )

        @staticmethod
        def _zero_obs():
            return {"observation": np.zeros(4), "achieved_goal": np.zeros(2)}

        def reset(self, *, seed=None, options=None):
            return self._zero_obs(), {}

        def step(self, action):
            return self._zero_obs(), 0.0, False, False, {}

    vec_env = DummyVecEnv([DictEnv] * 2)

    try:
        check_vecenv(vec_env, warn=True)
    finally:
        vec_env.close()


def test_check_vecenv_warnings():
    """An asymmetric ``Box`` action space must trigger the normalization warning."""

    class BoxActionEnv(gym.Env):
        def __init__(self):
            self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))
            # Neither symmetric nor normalized: the checker should warn about it
            self.action_space = spaces.Box(low=-2.0, high=3.0, shape=(2,))

        def reset(self, *, seed=None, options=None):
            return np.zeros(4), {}

        def step(self, action):
            return np.zeros(4), 0.0, False, False, {}

    vec_env = DummyVecEnv([BoxActionEnv] * 2)

    try:
        with pytest.warns(UserWarning, match="symmetric and normalized Box action space"):
            check_vecenv(vec_env, warn=True)
    finally:
        vec_env.close()