Skip to content

Commit 08c96e6

Browse files
committed
Add gymnasium
1 parent 12f2f35 commit 08c96e6

File tree

5 files changed

+229
-1
lines changed

5 files changed

+229
-1
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from lerobot.envs.gymnasium_robotics import GymRoboticsEnv
2+
import numpy as np
3+
4+
# Smoke test: build the adapter, inspect one observation, then run a single
# episode with an all-zero action until the env terminates or is truncated.
env = GymRoboticsEnv("FetchPickAndPlace-v4")
try:
    obs, info = env.reset()
    print({k: type(v) for k, v in obs.items()})
    print({k: v.shape for k, v in obs["images"].items()})
    print("state shape:", obs["state"].shape)
    print("goal in obs:", "goal" in obs)
    print(env.action_space)
    print(env.action_space.shape[0])

    # The action never changes, so build it once instead of once per step.
    action = np.zeros(env.action_space.shape, dtype=np.float32)
    done = False
    while not done:
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
    print("rollout ok")
finally:
    # Always release the wrapped env, even if reset/step raised.
    env.close()

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ aloha = ["gym-aloha>=0.1.2,<0.2.0"]
133133
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
134134
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
135135
metaworld = ["metaworld==3.0.0"]
136+
gymnasium-robotics = ["gymnasium-robotics>=1.4.1"]
136137

137138
# All
138139
all = [
@@ -155,6 +156,7 @@ all = [
155156
"lerobot[phone]",
156157
"lerobot[libero]",
157158
"lerobot[metaworld]",
159+
"lerobot[gymnasium-robotics]",
158160
]
159161

160162
[project.scripts]

src/lerobot/envs/configs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,19 @@ def gym_kwargs(self) -> dict:
319319
"obs_type": self.obs_type,
320320
"render_mode": self.render_mode,
321321
}
322+
323+
@EnvConfig.register_subclass("gymnasium-robotics")
@dataclass
class GymRoboticsEnv(EnvConfig):
    """Configuration for environments served through the Gymnasium-Robotics adapter.

    Registered under the choice name ``"gymnasium-robotics"``.
    """

    # minimal fields the factory/CLI may expect
    type: str = "gymnasium-robotics"
    # Gymnasium environment id. The env adapter in lerobot.envs.gymnasium_robotics
    # passes this string unchanged to gym.make(), so it has to be a registered id
    # (the previous default "fetch_pick_and_place" is not one).
    task: str = "FetchPickAndPlace-v4"
    # Optional episode step limit; None lets the adapter pick its own default.
    episode_length: int | None = None
    # Optional cap on the length of the flattened state vector.
    max_state_dim: int | None = None

    # Base seed; the vectorized factory offsets it by the worker index.
    seed: int | None = 0
    # Key under which the rendered RGB frame is stored in obs["images"].
    image_key: str = "agentview_image"

    @property
    def gym_kwargs(self) -> dict:
        """Return extra keyword arguments for env construction (none for this env type)."""
        return {}

src/lerobot/envs/factory.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import gymnasium as gym
1919
from gymnasium.envs.registration import registry as gym_registry
2020

21-
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
21+
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, GymRoboticsEnv
2222

2323

2424
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -28,6 +28,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
2828
return PushtEnv(**kwargs)
2929
elif env_type == "libero":
3030
return LiberoEnv(**kwargs)
31+
elif env_type == "gymnasium-robotics":
32+
return GymRoboticsEnv(**kwargs)
3133
else:
3234
raise ValueError(f"Policy type '{env_type}' is not available.")
3335

@@ -85,6 +87,12 @@ def make_env(
8587
gym_kwargs=cfg.gym_kwargs,
8688
env_cls=env_cls,
8789
)
90+
elif "gymnasium-robotics" in cfg.type:
91+
from lerobot.envs.gymnasium_robotics import create_gymnasium_robotics_envs
92+
93+
if cfg.task is None:
94+
raise ValueError("Gym robotics requires a task to be specified")
95+
return create_gymnasium_robotics_envs(cfg)
8896

8997
if cfg.gym_id not in gym_registry:
9098
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import gymnasium_robotics
2+
import gymnasium as gym
3+
import numpy as np
4+
from typing import Dict
5+
from lerobot.envs.configs import GymRoboticsEnv
6+
7+
def create_gymnasium_robotics_envs(
    cfg: GymRoboticsEnv,
    n_envs: int = 1,
    use_async_envs: bool = False,
) -> Dict[str, Dict[int, gym.vector.VectorEnv]]:
    """Create one vectorized env from ``cfg`` and wrap it in the nested dict make_env returns.

    Args:
        cfg: Environment config; ``task``, ``seed``, ``image_key``,
            ``episode_length`` and ``max_state_dim`` are read from it, each
            with a fallback when the attribute is absent.
        n_envs: Number of sub-environments in the vector env.
        use_async_envs: Build an ``AsyncVectorEnv`` when true, otherwise a
            ``SyncVectorEnv``.

    Returns:
        ``{"gymnasium-robotics": {0: <VectorEnv>}}``.
    """
    # Read every setting up front so the worker closures capture plain values.
    task_id = getattr(cfg, "task", "FetchPickAndPlace-v4")
    seed0 = getattr(cfg, "seed", 0)
    img_key = getattr(cfg, "image_key", "agentview_image")
    ep_len = getattr(cfg, "episode_length", None)
    state_cap = getattr(cfg, "max_state_dim", None)

    def _make_thunk(rank: int):
        """Return a zero-argument constructor for the sub-environment at ``rank``."""

        def _build():
            # Each worker gets its own seed (base seed + rank) unless seeding is disabled.
            if seed0 is None:
                worker_seed = None
            else:
                worker_seed = int(seed0) + rank
            return GymRoboticsEnv(
                task=task_id,
                seed=worker_seed,
                image_key=img_key,
                max_state_dim=state_cap,
                episode_length=ep_len,
            )

        return _build

    thunks = []
    for rank in range(n_envs):
        thunks.append(_make_thunk(rank))

    if use_async_envs:
        vector_env = gym.vector.AsyncVectorEnv(thunks)
    else:
        vector_env = gym.vector.SyncVectorEnv(thunks)

    # Flat key that mirrors the --env.type value.
    return {"gymnasium-robotics": {0: vector_env}}
36+
37+
class GymRoboticsEnv(gym.Env):
    """Minimal adapter: wraps a Gymnasium-Robotics env and returns a LeRobot-style obs dict.

    Every observation is a dict with:
        - ``"images"``: ``{image_key: rgb}`` where ``rgb`` is the rendered uint8 frame
        - ``"state"``: flattened float32 state vector
        - ``"agent_pos"``: alias of ``"state"``
        - ``"pixels"``: alias of the single RGB view
        - ``"goal"`` / ``"achieved_goal"``: flattened goals, present only when
          the wrapped env returns a dict observation that contains them
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}

    def __init__(
        self,
        task: str,
        seed: int | None = 0,
        image_key: str = "agentview_image",
        episode_length: int | None = None,
        max_state_dim: int | None = None,
        **make_kwargs,
    ):
        """Create the wrapped env and infer the observation space from one reset.

        Args:
            task: Gymnasium environment id, passed to ``gym.make``.
            seed: Seed for the internal RNG that generates per-reset seeds;
                ``None`` disables seeding.
            image_key: Key used for the rendered frame inside ``obs["images"]``.
            episode_length: Step limit passed to ``gym.make`` as
                ``max_episode_steps``; 1000 is used when ``None``.
            max_state_dim: Optional cap on the flattened state length (see the
                note in ``_to_obs`` about when it applies).
            **make_kwargs: Extra keyword arguments for ``gym.make``;
                ``render_mode`` is always forced to ``"rgb_array"``.
        """
        gym.register_envs(gymnasium_robotics)
        make_kwargs = dict(make_kwargs or {})
        make_kwargs["render_mode"] = "rgb_array"

        # Use the requested episode length as the actual time limit. Previously
        # 1000 was hard-coded here, so `episode_length` only changed the reported
        # `_max_episode_steps` while the env still ran for 1000 steps.
        time_limit = 1000 if episode_length is None else int(episode_length)
        self.env = gym.make(task, max_episode_steps=time_limit, **make_kwargs)

        self._rng = np.random.default_rng(seed)
        self._seed = seed
        self._image_key = image_key
        self._max_state_dim = max_state_dim

        # action space: forward from underlying env
        self.action_space = self.env.action_space

        # --- infer observation space once (do a temp reset+render) ---
        first_seed = int(self._rng.integers(0, 2**31 - 1)) if seed is not None else None
        tmp_obs, _ = self.env.reset(seed=first_seed)
        frame = self.env.render()
        obs = self._to_obs(tmp_obs, frame)

        def _box_like(x, low=-np.inf, high=np.inf, dtype=np.float32):
            """Return a Box with the same shape as ``x``."""
            x = np.asarray(x)
            return gym.spaces.Box(low=low, high=high, shape=x.shape, dtype=dtype)

        img = obs["images"][self._image_key]
        spaces = {
            "images": gym.spaces.Dict(
                {self._image_key: gym.spaces.Box(low=0, high=255, shape=img.shape, dtype=np.uint8)}
            ),
            "state": _box_like(obs["state"]),
            # aliases for libero-style preprocessors
            "agent_pos": _box_like(obs["state"]),
            "pixels": gym.spaces.Box(low=0, high=255, shape=img.shape, dtype=np.uint8),
        }
        if "goal" in obs:
            spaces["goal"] = _box_like(obs["goal"])
        if "achieved_goal" in obs:
            spaces["achieved_goal"] = _box_like(obs["achieved_goal"])

        self.observation_space = gym.spaces.Dict(spaces)
        # leave env in a valid state; vector wrapper will call reset() again later

        # passthrough spec (if present on wrapped env)
        self.spec = getattr(self.env, "spec", None)

        # Upstream code reads `_max_episode_steps`; it now always matches the
        # limit that was actually passed to gym.make above.
        self._max_episode_steps = time_limit

    def reset(self, seed: int | None = None, **kwargs):
        """Reset the wrapped env and return ``(observation, info)``.

        When no seed is given and the adapter was built with a seed, a fresh
        seed is drawn from the internal RNG. An ``options`` keyword argument,
        if supplied, is forwarded to the wrapped env; other keyword arguments
        are ignored.
        """
        if seed is None and self._seed is not None:
            seed = int(self._rng.integers(0, 2**31 - 1))
        super().reset(seed=seed)
        # Forward `options` instead of dropping it silently.
        tmp_obs, info = self.env.reset(seed=seed, options=kwargs.get("options"))
        frame = self.env.render()
        observation = self._to_obs(tmp_obs, frame)
        return observation, info

    def step(self, action):
        """Step the wrapped env and return the converted 5-tuple.

        For Box action spaces the action is cast to float32 and clipped to the
        space bounds before being applied.
        """
        if isinstance(self.action_space, gym.spaces.Box):
            action = np.clip(
                np.asarray(action, dtype=np.float32),
                self.action_space.low,
                self.action_space.high,
            )
        tmp_obs, reward, terminated, truncated, info = self.env.step(action)
        frame = self.env.render()
        obs_out = self._to_obs(tmp_obs, frame)
        return obs_out, float(reward), bool(terminated), bool(truncated), info

    def close(self):
        """Close the wrapped env."""
        self.env.close()

    def render(self):
        """Return an RGB frame (HxWx3, uint8) like Gymnasium expects."""
        frame = self.env.render()  # underlying env created with render_mode='rgb_array'
        if frame is None:
            raise RuntimeError("render() returned None; ensure render_mode='rgb_array' in make().")
        return frame.astype(np.uint8, copy=False)

    # ---- helpers ----
    @staticmethod
    def _flat(x):
        """Return ``x`` as a 1-D float32 array; ``None`` becomes an empty array."""
        if x is None:
            return np.zeros((0,), dtype=np.float32)
        return np.asarray(x, dtype=np.float32).reshape(-1)

    def _to_obs(self, obs, frame):
        """Convert a raw observation plus rendered frame into the LeRobot-style dict.

        Args:
            obs: Raw observation from the wrapped env (dict, image-like
                ndarray, or anything that can be flattened).
            frame: Rendered RGB frame, or ``None``.

        Raises:
            RuntimeError: If no image is available, i.e. ``frame`` is ``None``
                and ``obs`` is not itself an image.
        """
        if isinstance(obs, dict):
            state = self._flat(obs.get("observation"))
            desired = obs.get("desired_goal")
            achieved = obs.get("achieved_goal")
            rgb_src = frame
        elif isinstance(obs, np.ndarray) and obs.ndim == 3 and obs.shape[-1] in (1, 3):
            # Atari-style ndarray: treat as IMAGE, not state;
            # use obs as the frame if frame is None
            rgb_src = frame if frame is not None else obs
            # no structured state in Atari pixels; provide empty state vector
            state = np.empty((0,), dtype=np.float32)
            desired = achieved = None
        else:
            # fallback: unknown non-dict obs -> treat as state only
            state = self._flat(obs)
            # NOTE(review): max_state_dim is only applied in this branch; dict
            # observations are never truncated - confirm that is intended.
            if self._max_state_dim is not None and len(state) > self._max_state_dim:
                state = state[: self._max_state_dim]
            desired = achieved = None
            rgb_src = frame

        # Convert the image once, from the source chosen above. An unconditional
        # `frame.astype(...)` used to run here, which overrode the obs-as-frame
        # fallback and crashed when frame was None.
        if rgb_src is None:
            raise RuntimeError("render() returned None; ensure render_mode='rgb_array' in make().")
        rgb = rgb_src.astype(np.uint8, copy=False)

        out = {
            # gym original keys
            "images": {self._image_key: rgb},
            "state": state,
            # aliases expected by LeRobot preprocessors
            "agent_pos": state,  # alias for state
            "pixels": rgb,  # alias for a single RGB view
        }
        if desired is not None:
            out["goal"] = self._flat(desired)
        if achieved is not None:
            out["achieved_goal"] = self._flat(achieved)
        return out

0 commit comments

Comments
 (0)