Skip to content

[Versioning] Use pyvers instead of builtins #3124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ support multiple versions of gym without requiring any effort from the user side
For example, considering that our virtual environment has the v0.26.2 installed,
the following function will return ``1`` when queried:

>>> from torchrl._utils import implement_for
>>> from pyvers import implement_for
>>> @implement_for("gym", None, "0.26.0")
... def fun():
... return 0
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"packaging",
"cloudpickle",
"tensordict>=0.10.0,<0.11.0",
"pyvers",
]

[project.optional-dependencies]
Expand Down
9 changes: 2 additions & 7 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@
import pytest
import torch
import torch.cuda
from pyvers import implement_for
from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModuleBase
from torch import nn, vmap

from torchrl._utils import (
implement_for,
logger,
logger as torchrl_logger,
RL_WARNINGS,
seed_generator,
)
from torchrl._utils import logger, logger as torchrl_logger, RL_WARNINGS, seed_generator
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import MultiThreadedEnv, ObservationNorm
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
Expand Down
137 changes: 136 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import torch

from packaging import version
from pyvers import implement_for
from tensordict import (
assert_allclose_td,
is_tensor_collection,
Expand All @@ -47,7 +48,7 @@
)
from torch import nn

from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors.collectors import SyncDataCollector
from torchrl.data import (
Binary,
Expand Down Expand Up @@ -1649,6 +1650,140 @@ def _test_resetting_strategies(self, heterogeneous, kwargs):
del env
gc.collect()

def test_is_from_pixels_simple_env(self):
"""Test that _is_from_pixels correctly identifies non-pixel environments."""
from torchrl.envs.libs.gym import _is_from_pixels

# Test with a simple environment that doesn't have pixels
class SimpleEnv:
def __init__(self):
try:
import gymnasium as gym
except ImportError:
import gym
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(3,))

env = SimpleEnv()

# This should return False since it's not a pixel environment
result = _is_from_pixels(env)
assert result is False, f"Expected False for simple environment, got {result}"

def test_is_from_pixels_box_env(self):
"""Test that _is_from_pixels correctly identifies pixel Box environments."""
from torchrl.envs.libs.gym import _is_from_pixels

# Test with a pixel-like environment
class PixelEnv:
def __init__(self):
try:
import gymnasium as gym
except ImportError:
import gym
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(64, 64, 3)
)

pixel_env = PixelEnv()

# This should return True since it's a pixel environment
result = _is_from_pixels(pixel_env)
assert result is True, f"Expected True for pixel environment, got {result}"

def test_is_from_pixels_dict_env(self):
"""Test that _is_from_pixels correctly identifies Dict environments with pixels."""
from torchrl.envs.libs.gym import _is_from_pixels

# Test with a Dict environment that has pixels
class DictPixelEnv:
def __init__(self):
try:
import gymnasium as gym
except ImportError:
import gym
self.observation_space = gym.spaces.Dict(
{
"pixels": gym.spaces.Box(low=0, high=255, shape=(64, 64, 3)),
"state": gym.spaces.Box(low=-1, high=1, shape=(3,)),
}
)

dict_pixel_env = DictPixelEnv()

# This should return True since it has a "pixels" key
result = _is_from_pixels(dict_pixel_env)
assert (
result is True
), f"Expected True for Dict environment with pixels, got {result}"

def test_is_from_pixels_dict_env_no_pixels(self):
"""Test that _is_from_pixels correctly identifies Dict environments without pixels."""
from torchrl.envs.libs.gym import _is_from_pixels

# Test with a Dict environment that doesn't have pixels
class DictNoPixelEnv:
def __init__(self):
try:
import gymnasium as gym
except ImportError:
import gym
self.observation_space = gym.spaces.Dict(
{
"state": gym.spaces.Box(low=-1, high=1, shape=(3,)),
"features": gym.spaces.Box(low=0, high=1, shape=(5,)),
}
)

dict_no_pixel_env = DictNoPixelEnv()

# This should return False since it doesn't have a "pixels" key
result = _is_from_pixels(dict_no_pixel_env)
assert (
result is False
), f"Expected False for Dict environment without pixels, got {result}"

def test_is_from_pixels_wrapper_env(self):
"""Test that _is_from_pixels correctly identifies wrapped environments."""
from torchrl.envs.libs.gym import _is_from_pixels

# Test with a mock environment that simulates being wrapped with a pixel wrapper
class MockWrappedEnv:
def __init__(self):
try:
import gymnasium as gym
except ImportError:
import gym
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(64, 64, 3)
)

# Mock the isinstance check to simulate the wrapper detection
import torchrl.envs.libs.utils

original_isinstance = isinstance

def mock_isinstance(obj, cls):
if cls == torchrl.envs.libs.utils.GymPixelObservationWrapper:
return True
return original_isinstance(obj, cls)

# Temporarily patch isinstance
import builtins

builtins.isinstance = mock_isinstance

try:
wrapped_env = MockWrappedEnv()

# This should return True since it's detected as a pixel wrapper
result = _is_from_pixels(wrapped_env)
assert (
result is True
), f"Expected True for wrapped environment, got {result}"
finally:
# Restore original isinstance
builtins.isinstance = original_isinstance


@pytest.mark.skipif(
not _has_minigrid or not _has_gymnasium, reason="MiniGrid not found"
Expand Down
3 changes: 2 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
else:
from _utils_internal import capture_log_records, get_default_devices
from packaging import version
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
from pyvers import implement_for
from torchrl._utils import _rng_decorator, get_binary_env_var

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend

Expand Down
2 changes: 0 additions & 2 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from torchrl._utils import (
auto_unwrap_transformed_env,
compile_with_warmup,
implement_for,
logger,
set_auto_unwrap_transformed_env,
timeit,
Expand Down Expand Up @@ -108,7 +107,6 @@ def _inv(self):
__all__ = [
"auto_unwrap_transformed_env",
"compile_with_warmup",
"implement_for",
"set_auto_unwrap_transformed_env",
"timeit",
"logger",
Expand Down
Loading
Loading