diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 7aa6fd7a42c..45dc3a811bd 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -1395,7 +1395,7 @@ support multiple versions of gym without requiring any effort from the user side For example, considering that our virtual environment has the v0.26.2 installed, the following function will return ``1`` when queried: - >>> from torchrl._utils import implement_for + >>> from pyvers import implement_for >>> @implement_for("gym", None, "0.26.0") ... def fun(): ... return 0 diff --git a/pyproject.toml b/pyproject.toml index 9a6014d269d..d4b0db96694 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "packaging", "cloudpickle", "tensordict>=0.10.0,<0.11.0", + "pyvers", ] [project.optional-dependencies] diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 79922fe1917..e7914f60335 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -17,17 +17,12 @@ import pytest import torch import torch.cuda +from pyvers import implement_for from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase from tensordict.nn import TensorDictModuleBase from torch import nn, vmap -from torchrl._utils import ( - implement_for, - logger, - logger as torchrl_logger, - RL_WARNINGS, - seed_generator, -) +from torchrl._utils import logger, logger as torchrl_logger, RL_WARNINGS, seed_generator from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import MultiThreadedEnv, ObservationNorm from torchrl.envs.batched_envs import ParallelEnv, SerialEnv diff --git a/test/test_libs.py b/test/test_libs.py index fb1066cc361..f94bd7dd086 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -34,6 +34,7 @@ import torch from packaging import version +from pyvers import implement_for from tensordict import ( assert_allclose_td, is_tensor_collection, @@ -47,7 +48,7 @@ ) from torch import nn -from torchrl._utils import implement_for, logger as torchrl_logger +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.collectors import SyncDataCollector from torchrl.data import ( Binary, @@ -1649,6 +1650,140 @@ def _test_resetting_strategies(self, heterogeneous, kwargs): del env gc.collect() + def test_is_from_pixels_simple_env(self): + """Test that _is_from_pixels correctly identifies non-pixel environments.""" + from torchrl.envs.libs.gym import _is_from_pixels + + # Test with a simple environment that doesn't have pixels + class SimpleEnv: + def __init__(self): + try: + import gymnasium as gym + except ImportError: + import gym + self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) + + env = SimpleEnv() + + # This should return False since it's not a pixel environment + result = _is_from_pixels(env) + assert result is False, f"Expected False for simple environment, got {result}" + + def test_is_from_pixels_box_env(self): + """Test that _is_from_pixels correctly identifies pixel Box environments.""" + from torchrl.envs.libs.gym import _is_from_pixels + + # Test with a pixel-like environment + class PixelEnv: + def __init__(self): + try: + import gymnasium as gym + except ImportError: + import gym + self.observation_space = gym.spaces.Box( + low=0, high=255, shape=(64, 64, 3) + ) + + pixel_env = PixelEnv() + + # This should return True since it's a pixel environment + result = _is_from_pixels(pixel_env) + assert result is True, f"Expected True for pixel environment, got {result}" + + def test_is_from_pixels_dict_env(self): + """Test that _is_from_pixels correctly identifies Dict environments with pixels.""" + from torchrl.envs.libs.gym import _is_from_pixels + + # Test with a Dict environment that has pixels + class DictPixelEnv: + def __init__(self): + try: + import gymnasium as gym + except ImportError: + import gym + self.observation_space = gym.spaces.Dict( + { + "pixels": gym.spaces.Box(low=0, high=255, shape=(64, 64, 3)), + "state": gym.spaces.Box(low=-1, high=1, shape=(3,)), + } + ) + + dict_pixel_env = DictPixelEnv() + + # This should return True since it has a "pixels" key + result = _is_from_pixels(dict_pixel_env) + assert ( + result is True + ), f"Expected True for Dict environment with pixels, got {result}" + + def test_is_from_pixels_dict_env_no_pixels(self): + """Test that _is_from_pixels correctly identifies Dict environments without pixels.""" + from torchrl.envs.libs.gym import _is_from_pixels + + # Test with a Dict environment that doesn't have pixels + class DictNoPixelEnv: + def __init__(self): + try: + import gymnasium as gym + except ImportError: + import gym + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Box(low=-1, high=1, shape=(3,)), + "features": gym.spaces.Box(low=0, high=1, shape=(5,)), + } + ) + + dict_no_pixel_env = DictNoPixelEnv() + + # This should return False since it doesn't have a "pixels" key + result = _is_from_pixels(dict_no_pixel_env) + assert ( + result is False + ), f"Expected False for Dict environment without pixels, got {result}" + + def test_is_from_pixels_wrapper_env(self): + """Test that _is_from_pixels correctly identifies wrapped environments.""" + from torchrl.envs.libs.gym import _is_from_pixels + + # Test with a mock environment that simulates being wrapped with a pixel wrapper + class MockWrappedEnv: + def __init__(self): + try: + import gymnasium as gym + except ImportError: + import gym + self.observation_space = gym.spaces.Box( + low=0, high=255, shape=(64, 64, 3) + ) + + # Mock the isinstance check to simulate the wrapper detection + import torchrl.envs.libs.utils + + original_isinstance = isinstance + + def mock_isinstance(obj, cls): + if cls == torchrl.envs.libs.utils.GymPixelObservationWrapper: + return True + return original_isinstance(obj, cls) + + # Temporarily patch isinstance + import builtins + + builtins.isinstance = mock_isinstance + + try: + wrapped_env = MockWrappedEnv() + + # This should return True since it's detected as a pixel wrapper + result = _is_from_pixels(wrapped_env) + assert ( + result is True + ), f"Expected True for wrapped environment, got {result}" + finally: + # Restore original isinstance + builtins.isinstance = original_isinstance + @pytest.mark.skipif( not _has_minigrid or not _has_gymnasium, reason="MiniGrid not found" diff --git a/test/test_utils.py b/test/test_utils.py index 98cb23adadc..7daf7a0f757 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -23,7 +23,8 @@ else: from _utils_internal import capture_log_records, get_default_devices from packaging import version -from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for +from pyvers import implement_for +from torchrl._utils import _rng_decorator, get_binary_env_var from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 2032f5175c5..a0f103ab01e 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -49,7 +49,6 @@ from torchrl._utils import ( auto_unwrap_transformed_env, compile_with_warmup, - implement_for, logger, set_auto_unwrap_transformed_env, timeit, @@ -108,7 +107,6 @@ def _inv(self): __all__ = [ "auto_unwrap_transformed_env", "compile_with_warmup", - "implement_for", "set_auto_unwrap_transformed_env", "timeit", "logger", diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 043c0298c1a..882c5acc033 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -17,15 +17,12 @@ import traceback import warnings from contextlib import nullcontext -from copy import copy from functools import wraps -from importlib import import_module from textwrap import indent from typing import Any, Callable, cast, TypeVar import numpy as np import torch -from packaging.version import parse from tensordict import unravel_key from tensordict.utils import NestedKey from torch import multiprocessing as mp, Tensor @@ -389,263 +386,6 @@ def __repr__(self): _CKPT_BACKEND = _Dynamic_CKPT_BACKEND() -class implement_for: - """A version decorator that checks the version in the environment and implements a function with the fitting one. - - If specified module is missing or there is no fitting implementation, call of the decorated function - will lead to the explicit error. - In case of intersected ranges, last fitting implementation is used. - - This wrapper also works to implement different backends for a same function (eg. gym vs gymnasium, - numpy vs jax-numpy etc). - - Args: - module_name (str or callable): version is checked for the module with this - name (e.g. "gym"). If a callable is provided, it should return the - module. - from_version: version from which implementation is compatible. Can be open (None). - to_version: version from which implementation is no longer compatible. Can be open (None). - - Keyword Args: - class_method (bool, optional): if ``True``, the function will be written as a class method. - Defaults to ``False``. - compilable (bool, optional): If ``False``, the module import happens - only on the first call to the wrapped function. If ``True``, the - module import happens when the wrapped function is initialized. This - allows the wrapped function to work well with ``torch.compile``. - Defaults to ``False``. - - Examples: - >>> @implement_for("gym", "0.13", "0.14") - >>> def fun(self, x): - ... # Older gym versions will return x + 1 - ... return x + 1 - ... - >>> @implement_for("gym", "0.14", "0.23") - >>> def fun(self, x): - ... # More recent gym versions will return x + 2 - ... return x + 2 - ... - >>> @implement_for(lambda: import_module("gym"), "0.23", None) - >>> def fun(self, x): - ... # More recent gym versions will return x + 2 - ... return x + 2 - ... - >>> @implement_for("gymnasium", None, "1.0.0") - >>> def fun(self, x): - ... # If gymnasium is to be used instead of gym, x+3 will be returned - ... return x + 3 - ... - - This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+. - """ - - # Stores pointers to fitting implementations: dict[func_name] = func_pointer - _implementations = {} - _setters = [] - _cache_modules = {} - - def __init__( - self, - module_name: str | Callable, - from_version: str = None, - to_version: str = None, - *, - class_method: bool = False, - compilable: bool = False, - ): - self.module_name = module_name - self.from_version = from_version - self.to_version = to_version - self.class_method = class_method - self._compilable = compilable - implement_for._setters.append(self) - - @staticmethod - def check_version(version: str, from_version: str | None, to_version: str | None): - version = parse(".".join([str(v) for v in parse(version).release])) - return (from_version is None or version >= parse(from_version)) and ( - to_version is None or version < parse(to_version) - ) - - @staticmethod - def get_class_that_defined_method(f): - """Returns the class of a method, if it is defined, and None otherwise.""" - out = f.__globals__.get(f.__qualname__.split(".")[0], None) - return out - - @classmethod - def get_func_name(cls, fn): - # produces a name like torchrl.module.Class.method or torchrl.module.function - fn_str = str(fn).split(".") - if fn_str[0].startswith(" str: - """Imports module and returns its version.""" - if not callable(module_name): - module = cls._cache_modules.get(module_name, None) - if module is None: - if module_name in sys.modules: - sys.modules[module_name] = module = import_module(module_name) - else: - cls._cache_modules[module_name] = module = import_module( - module_name - ) - else: - module = module_name() - return module.__version__ - - _lazy_impl = collections.defaultdict(list) - - def _delazify(self, func_name): - out = None - for local_call in implement_for._lazy_impl[func_name]: - out = local_call() - return out - - def __call__(self, fn): - # function names are unique - self.func_name = self.get_func_name(fn) - self.fn = fn - implement_for._lazy_impl[self.func_name].append(self._call) - - if self._compilable: - _call_fn = self._delazify(self.func_name) - - if self.class_method: - return classmethod(_call_fn) - - return _call_fn - else: - - @wraps(fn) - def _lazy_call_fn(*args, **kwargs): - # first time we call the function, we also do the replacement. - # This will cause the imports to occur only during the first call to fn - - result = self._delazify(self.func_name)(*args, **kwargs) - return result - - if self.class_method: - return classmethod(_lazy_call_fn) - - return _lazy_call_fn - - def _call(self): - - # If the module is missing replace the function with the mock. - fn = self.fn - func_name = self.func_name - implementations = implement_for._implementations - - @wraps(fn) - def unsupported(*args, **kwargs): - raise ModuleNotFoundError( - f"Supported version of '{func_name}' has not been found." - ) - - self.do_set = False - # Return fitting implementation if it was encountered before. - if func_name in implementations: - try: - # check that backends don't conflict - version = self.import_module(self.module_name) - if self.check_version(version, self.from_version, self.to_version): - if VERBOSE: - module = import_module(self.module_name) - warnings.warn( - f"Got multiple backends for {func_name}. " - f"Using the last queried ({module} with version {version})." - ) - self.do_set = True - if not self.do_set: - return implementations[func_name].fn - except ModuleNotFoundError: - # then it's ok, there is no conflict - return implementations[func_name].fn - else: - try: - version = self.import_module(self.module_name) - if self.check_version(version, self.from_version, self.to_version): - self.do_set = True - except ModuleNotFoundError: - return unsupported - if self.do_set: - self.module_set() - return fn - return unsupported - - @classmethod - def reset(cls, setters_dict: dict[str, implement_for] = None): - """Resets the setters in setter_dict. - - ``setter_dict`` is a copy of implementations. We just need to iterate through its - values and call :meth:`module_set` for each. - - """ - if VERBOSE: - logger.info("resetting implement_for") - if setters_dict is None: - setters_dict = copy(cls._implementations) - for setter in setters_dict.values(): - setter.module_set() - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"module_name={self.module_name}({self.from_version, self.to_version}), " - f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)})" - ) - - def accept_remote_rref_invocation(func): """Decorator that allows a method to be invoked remotely. diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index d89aad0fc9c..fb4e037c666 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -18,6 +18,7 @@ import numpy as np import tensordict import torch +from pyvers import implement_for from tensordict import ( is_tensor_collection, lazy_stack, @@ -31,7 +32,7 @@ from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten -from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger +from torchrl._utils import _make_ordinal_device, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( CompressedListStorageCheckpointer, ListStorageCheckpointer, diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index 97c62bf9707..30768cab525 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -16,6 +16,7 @@ import numpy as np import torch +from pyvers import implement_for from tensordict import ( lazy_stack, MemoryMappedTensor, @@ -27,7 +28,7 @@ from torch import Tensor from torch.nn import functional as F from torch.utils._pytree import LeafSpec, tree_flatten, tree_unflatten -from torchrl._utils import implement_for, logger as torchrl_logger +from torchrl._utils import logger as torchrl_logger SINGLE_TENSOR_BUFFER_NAME = os.environ.get( "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_" diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7819146594d..dfcf655c9ad 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -34,6 +34,7 @@ import tensordict import torch +from pyvers import implement_for from tensordict import ( is_tensor_collection, lazy_stack, @@ -52,7 +53,7 @@ is_non_tensor, NestedKey, ) -from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for +from torchrl._utils import _make_ordinal_device, get_binary_env_var try: from torch.compiler import is_compiling diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 3d1eceb2f14..8549952d4c6 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -16,6 +16,7 @@ import numpy as np import torch import torch.nn as nn +from pyvers import implement_for from tensordict import ( is_tensor_collection, LazyStackedTensorDict, @@ -28,7 +29,6 @@ _ends_with, _make_ordinal_device, _replace_last, - implement_for, prod, seed_generator, ) diff --git a/torchrl/envs/libs/_gym_utils.py b/torchrl/envs/libs/_gym_utils.py index 3f6292932da..e16ac5cfd96 100644 --- a/torchrl/envs/libs/_gym_utils.py +++ b/torchrl/envs/libs/_gym_utils.py @@ -7,11 +7,11 @@ import importlib.util import torch + +from pyvers import implement_for from tensordict.utils import unravel_key from torch.utils._pytree import tree_map - -from torchrl._utils import implement_for from torchrl.data.tensor_specs import Composite from torchrl.envs import step_mdp, TransformedEnv from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform, GYMNASIUM_1_ERROR diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 371969417b6..e140c24cdee 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -8,7 +8,6 @@ import collections import importlib import warnings -from copy import copy from types import ModuleType from typing import Dict from warnings import warn @@ -16,10 +15,15 @@ import numpy as np import torch from packaging import version + +from pyvers import ( + BackendManager, + gym_backend as pyvers_gym_backend, + implement_for, + register_backend, +) from tensordict import TensorDict, TensorDictBase from torch.utils._pytree import tree_map - -from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( _minmax_dtype, Binary, @@ -39,11 +43,7 @@ from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv from torchrl.envs.utils import _classproperty -try: - from torch.utils._contextlib import _DecoratorContextManager -except ModuleNotFoundError: - from torchrl._utils import _DecoratorContextManager - +# Keep these for backward compatibility with existing code DEFAULT_GYM = None IMPORT_ERROR = None # check gym presence without importing it @@ -56,6 +56,14 @@ _has_isaaclab = importlib.util.find_spec("isaaclab") is not None _has_minigrid = importlib.util.find_spec("minigrid") is not None +# Register gym backends with pyvers +register_backend( + "gym", + { + "gym": "gym", + "gymnasium": "gymnasium", + }, +) GYMNASIUM_1_ERROR = """RuntimeError: TorchRL does not support gymnasium 1.0 versions due to incompatible changes in the Gym API. @@ -80,7 +88,7 @@ def _minigrid_lib(): return minigrid -class set_gym_backend(_DecoratorContextManager): +class set_gym_backend(BackendManager): """Sets the gym-backend to a certain value. Args: @@ -124,71 +132,7 @@ class set_gym_backend(_DecoratorContextManager): """ def __init__(self, backend): - self.backend = backend - - def _call(self): - """Sets the backend as default.""" - global DEFAULT_GYM - DEFAULT_GYM = self.backend - found_setters = collections.defaultdict(bool) - for setter in copy(implement_for._setters): - check_module = ( - callable(setter.module_name) - and setter.module_name.__name__ == self.backend.__name__ - ) or setter.module_name == self.backend.__name__ - check_version = setter.check_version( - self.backend.__version__, setter.from_version, setter.to_version - ) - if check_module and check_version: - setter.module_set() - found_setter = True - elif check_module: - found_setter = False - else: - found_setter = None - if found_setter is not None: - found_setters[setter.func_name] = ( - found_setters[setter.func_name] or found_setter - ) - # we keep only the setters we need. This is safe because a copy is saved under self._setters_saved - for func_name, found_setter in found_setters.items(): - if not found_setter: - raise ImportError( - f"could not set anything related to gym backend " - f"{self.backend.__name__} with version={self.backend.__version__} for the function with name {func_name}. " - f"Check that the gym versions match!" - ) - - def set(self): - """Irreversibly sets the gym backend in the script.""" - self._call() - - def __enter__(self): - # we save a complete list of setters as well as whether they should be set. - # we want the full list because we want to be able to nest the calls to set_gym_backend. - # we also want to keep track of which ones are set to reproduce what was set before. - self._setters_saved = copy(implement_for._implementations) - self._call() - - def __exit__(self, exc_type, exc_val, exc_tb): - implement_for.reset(setters_dict=self._setters_saved) - delattr(self, "_setters_saved") - - def clone(self): - # override this method if your children class takes __init__ parameters - return self.__class__(self.backend) - - @property - def backend(self): - if isinstance(self._backend, str): - return importlib.import_module(self._backend) - elif callable(self._backend): - return self._backend() - return self._backend - - @backend.setter - def backend(self, value): - self._backend = value + super().__init__("gym", backend) def gym_backend(submodule=None): @@ -207,26 +151,42 @@ def gym_backend(submodule=None): ... wrappers = gym_backend('wrappers') ... print(wrappers) """ - global IMPORT_ERROR - global DEFAULT_GYM - if DEFAULT_GYM is None: - try: - # rule of thumbs: gymnasium precedes - import gymnasium as gym - except ImportError as err: - IMPORT_ERROR = err + # Get the current backend from pyvers + try: + backend = pyvers_gym_backend() + except ImportError: + backend = None + + if backend is None: + global IMPORT_ERROR + global DEFAULT_GYM + if DEFAULT_GYM is None: try: - import gym as gym + # rule of thumbs: gymnasium precedes + import gymnasium as gym except ImportError as err: IMPORT_ERROR = err - gym = None - DEFAULT_GYM = gym - if submodule is not None: - if not submodule.startswith("."): - submodule = "." + submodule - submodule = importlib.import_module(submodule, package=DEFAULT_GYM.__name__) - return submodule - return DEFAULT_GYM + try: + import gym as gym + except ImportError as err: + IMPORT_ERROR = err + gym = None + DEFAULT_GYM = gym + backend = DEFAULT_GYM + + if backend is None: + raise ImportError("No gym backend could be loaded") + + if submodule is None: + return backend + + # Get the submodule + try: + return getattr(backend, submodule) + except AttributeError: + raise AttributeError( + f"Module {backend.__name__} has no attribute '{submodule}'" + ) __all__ = ["GymWrapper", "GymEnv"] @@ -747,17 +707,55 @@ def _get_gym_envs(): # noqa: F811 return gym.envs.registration.registry.keys() +@implement_for("gym") def _is_from_pixels(env): observation_spec = env.observation_space - try: - PixelObservationWrapper = gym_backend( - "wrappers.pixel_observation" - ).PixelObservationWrapper - except ModuleNotFoundError: + from torchrl.envs.libs.utils import ( + GymPixelObservationWrapper as LegacyPixelObservationWrapper, + ) - class PixelObservationWrapper: - pass + gDict = gym_backend("spaces").dict.Dict + Box = gym_backend("spaces").Box + # Check if it's a gymnasium Dict space + if isinstance(observation_spec, (gDict,)): + if "pixels" in set(observation_spec.spaces.keys()): + return True + # Check if it's a pixel-like Box space + elif ( + isinstance(observation_spec, Box) + and (observation_spec.low == 0).all() + and (observation_spec.high == 255).all() + and observation_spec.low.shape[-1] == 3 + and observation_spec.low.ndim == 3 + ): + return True + else: + while True: + # For gym, try PixelObservationWrapper + try: + PixelObservationWrapper = gym_backend( + "wrappers.pixel_observation" + ).PixelObservationWrapper + if isinstance(env, PixelObservationWrapper): + return True + except (ModuleNotFoundError, AttributeError): + pass + + # Check our custom wrapper + if isinstance(env, LegacyPixelObservationWrapper): + return True + + if hasattr(env, "env"): + env = env.env + else: + break + return False + + +@implement_for("gymnasium", None, "1.1.0") +def _is_from_pixels(env): # noqa: F811 + observation_spec = env.observation_space from torchrl.envs.libs.utils import ( GymPixelObservationWrapper as LegacyPixelObservationWrapper, ) @@ -765,12 +763,57 @@ class PixelObservationWrapper: gDict = gym_backend("spaces").dict.Dict Box = gym_backend("spaces").Box - if isinstance(observation_spec, (Dict,)): - if "pixels" in set(observation_spec.keys()): + # Check if it's a gymnasium Dict space + if isinstance(observation_spec, (gDict,)): + if "pixels" in set(observation_spec.spaces.keys()): return True + # Check if it's a pixel-like Box space + elif ( + isinstance(observation_spec, Box) + and (observation_spec.low == 0).all() + and (observation_spec.high == 255).all() + and observation_spec.low.shape[-1] == 3 + and observation_spec.low.ndim == 3 + ): + return True + else: + while True: + # For gymnasium < 1.1.0, try PixelObservationWrapper + try: + PixelObservationWrapper = gym_backend( + "wrappers.pixel_observation" + ).PixelObservationWrapper + if isinstance(env, PixelObservationWrapper): + return True + except (ModuleNotFoundError, AttributeError): + pass + + # Check our custom wrapper + if isinstance(env, LegacyPixelObservationWrapper): + return True + + if hasattr(env, "env"): + env = env.env + else: + break + return False + + +@implement_for("gymnasium", "1.1.0") +def _is_from_pixels(env): # noqa: F811 + observation_spec = env.observation_space + from torchrl.envs.libs.utils import ( + GymPixelObservationWrapper as LegacyPixelObservationWrapper, + ) + + gDict = gym_backend("spaces").dict.Dict + Box = gym_backend("spaces").Box + + # Check if it's a gymnasium Dict space if isinstance(observation_spec, (gDict,)): if "pixels" in set(observation_spec.spaces.keys()): return True + # Check if it's a pixel-like Box space elif ( isinstance(observation_spec, Box) and (observation_spec.low == 0).all() @@ -781,10 +824,18 @@ class PixelObservationWrapper: return True else: while True: - if isinstance( - env, (LegacyPixelObservationWrapper, PixelObservationWrapper) - ): + # For gymnasium >= 1.1.0, use AddRenderObservation + try: + AddRenderObservation = gym_backend("wrappers").AddRenderObservation + if isinstance(env, AddRenderObservation): + return True + except (ModuleNotFoundError, AttributeError): + pass + + # Check our custom wrapper + if isinstance(env, LegacyPixelObservationWrapper): return True + if hasattr(env, "env"): env = env.env else: