Skip to content

Commit a95b15c

Browse files
jadechogharibsprengerpre-commit-ci[bot]AdilZouitine
authored
refactor(env): introduce explicit gym ID handling in EnvConfig/factory (#2234)
* refactor(env): introduce explicit gym ID handling in EnvConfig/factory This commit introduces properties for the gym package/ID associated with and environment config. They default to the current defaults (`gym_{package_name}/{task_id}`) to avoid breaking changes, but allow for easier use of external gym environments. Subclasses of `EnvConfig` can override the default properties to allow the factory to import (i.e. register) the gym env from a specific module, and also instantiate the env from any ID string. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more changes * quality * fix test --------- Co-authored-by: Ben Sprenger <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adil Zouitine <[email protected]>
1 parent a97d078 commit a95b15c

File tree

3 files changed

+70
-8
lines changed

3 files changed

+70
-8
lines changed

src/lerobot/envs/configs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
3737
def type(self) -> str:
3838
return self.get_choice_name(self.__class__)
3939

40+
@property
41+
def package_name(self) -> str:
42+
"""Package name to import if environment not found in gym registry"""
43+
return f"gym_{self.type}"
44+
45+
@property
46+
def gym_id(self) -> str:
47+
"""ID string used in gym.make() to instantiate the environment"""
48+
return f"{self.package_name}/{self.task}"
49+
4050
@property
4151
@abc.abstractmethod
4252
def gym_kwargs(self) -> dict:

src/lerobot/envs/factory.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import importlib
1717

1818
import gymnasium as gym
19+
from gymnasium.envs.registration import registry as gym_registry
1920

2021
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
2122

@@ -84,17 +85,24 @@ def make_env(
8485
gym_kwargs=cfg.gym_kwargs,
8586
env_cls=env_cls,
8687
)
87-
package_name = f"gym_{cfg.type}"
88-
try:
89-
importlib.import_module(package_name)
90-
except ModuleNotFoundError as e:
91-
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
92-
raise e
9388

94-
gym_handle = f"{package_name}/{cfg.task}"
89+
if cfg.gym_id not in gym_registry:
90+
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
91+
try:
92+
importlib.import_module(cfg.package_name)
93+
except ModuleNotFoundError as e:
94+
raise ModuleNotFoundError(
95+
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
96+
f"Please install it or check PYTHONPATH."
97+
) from e
98+
99+
if cfg.gym_id not in gym_registry:
100+
raise gym.error.NameNotFound(
101+
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
102+
)
95103

96104
def _make_one():
97-
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
105+
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
98106

99107
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
100108

tests/envs/test_envs.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
import importlib
17+
from dataclasses import dataclass, field
1718

1819
import gymnasium as gym
1920
import pytest
2021
import torch
22+
from gymnasium.envs.registration import register, registry as gym_registry
2123
from gymnasium.utils.env_checker import check_env
2224

2325
import lerobot
26+
from lerobot.configs.types import PolicyFeature
27+
from lerobot.envs.configs import EnvConfig
2428
from lerobot.envs.factory import make_env, make_env_config
2529
from lerobot.envs.utils import preprocess_observation
2630
from tests.utils import require_env
@@ -64,3 +68,43 @@ def test_factory(env_name):
6468
assert img.min() >= 0.0
6569

6670
env.close()
71+
72+
73+
def test_factory_custom_gym_id():
74+
gym_id = "dummy_gym_pkg/DummyTask-v0"
75+
if gym_id in gym_registry:
76+
pytest.skip(f"Environment ID {gym_id} is already registered")
77+
78+
@EnvConfig.register_subclass("dummy")
79+
@dataclass
80+
class DummyEnv(EnvConfig):
81+
task: str = "DummyTask-v0"
82+
fps: int = 10
83+
features: dict[str, PolicyFeature] = field(default_factory=dict)
84+
85+
@property
86+
def package_name(self) -> str:
87+
return "dummy_gym_pkg"
88+
89+
@property
90+
def gym_id(self) -> str:
91+
return gym_id
92+
93+
@property
94+
def gym_kwargs(self) -> dict:
95+
return {}
96+
97+
try:
98+
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
99+
100+
cfg = DummyEnv()
101+
envs_dict = make_env(cfg, n_envs=1)
102+
dummy_envs = envs_dict["dummy"]
103+
assert len(dummy_envs) == 1
104+
env = next(iter(dummy_envs.values()))
105+
assert env is not None and isinstance(env, gym.vector.VectorEnv)
106+
env.close()
107+
108+
finally:
109+
if gym_id in gym_registry:
110+
del gym_registry[gym_id]

0 commit comments

Comments
 (0)