
Commit a9b66c6

Add gym-robotics

1 parent 1ff8986 commit a9b66c6

5 files changed (+201, -1 lines)

examples/smoke_gym_robotics.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
from lerobot.envs.gym_robotics import GymRoboticsEnv
import numpy as np

env = GymRoboticsEnv("FetchPickAndPlace-v4")
obs, info = env.reset()
print({k: type(v) for k, v in obs.items()})
print({k: v.shape for k, v in obs["images"].items()})
print("state shape:", obs["state"].shape)
print("goal in obs:", "goal" in obs)
print(env.action_space)
print(env.action_space.shape[0])

done = False
while not done:
    action = np.zeros(env.action_space.shape, dtype=np.float32)
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
print("rollout ok")
env.close()

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -133,6 +133,7 @@ aloha = ["gym-aloha>=0.1.2,<0.2.0"]
 pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
 libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
 metaworld = ["metaworld>=3.0.0"]
+gym-robotics = ["gymnasium-robotics>=1.4.1"]
 
 # All
 all = [
@@ -155,6 +156,7 @@ all = [
     "lerobot[phone]",
     "lerobot[libero]",
     "lerobot[metaworld]",
+    "lerobot[gym-robotics]",
 ]
 
 [project.scripts]
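
With this extra declared, the new dependency installs via pip install "lerobot[gym-robotics]". A quick sanity check after install (a sketch; assumes gymnasium-robotics exposes the usual __version__ attribute):

# Verify the optional dependency resolved.
# __version__ is assumed here; recent gymnasium-robotics releases define it.
import gymnasium_robotics

print(gymnasium_robotics.__version__)  # expect >= 1.4.1 per the pin above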

src/lerobot/envs/configs.py

Lines changed: 14 additions & 0 deletions
@@ -309,3 +309,17 @@ def gym_kwargs(self) -> dict:
             "obs_type": self.obs_type,
             "render_mode": self.render_mode,
         }
+
+
+@EnvConfig.register_subclass("gym-robotics")
+@dataclass
+class GymRoboticsEnv(EnvConfig):
+    # minimal fields the factory/CLI may expect
+    type: str = "gym-robotics"
+    task: str = "FetchPickAndPlace-v4"  # must be a registered Gymnasium id; passed straight to gym.make()
+
+    seed: int | None = 0
+    image_key: str = "agentview_image"
+
+    @property
+    def gym_kwargs(self) -> dict:
+        return {}
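
For reference, the registered config can also be built directly; a minimal sketch using only fields from this diff (gym_kwargs is intentionally empty because the adapter in gym_robotics.py assembles its own gym.make arguments):

from lerobot.envs.configs import GymRoboticsEnv

cfg = GymRoboticsEnv(task="FetchPickAndPlace-v4", seed=7)
print(cfg.type)        # "gym-robotics"
print(cfg.gym_kwargs)  # {} -- kwargs are assembled in create_gym_robotics_envs instead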

src/lerobot/envs/factory.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 import gymnasium as gym
 
-from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
+from lerobot.envs.configs import AlohaEnv, EnvConfig, GymRoboticsEnv, LiberoEnv, PushtEnv
 
 
 def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
         return PushtEnv(**kwargs)
     elif env_type == "libero":
         return LiberoEnv(**kwargs)
+    elif env_type == "gym-robotics":
+        return GymRoboticsEnv(**kwargs)
     else:
         raise ValueError(f"Policy type '{env_type}' is not available.")
 
@@ -84,6 +86,13 @@ def make_env(
             gym_kwargs=cfg.gym_kwargs,
             env_cls=env_cls,
         )
+    elif "gym-robotics" in cfg.type:
+        from lerobot.envs.gym_robotics import create_gym_robotics_envs
+
+        if cfg.task is None:
+            raise ValueError("gym-robotics requires a task to be specified")
+        return create_gym_robotics_envs(cfg)
+
     package_name = f"gym_{cfg.type}"
     try:
         importlib.import_module(package_name)
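
The new factory branch can be exercised end to end roughly as below (a sketch assuming make_env(cfg) accepts its defaults; note the branch calls create_gym_robotics_envs(cfg) without forwarding worker counts, so it always builds a single synchronous env):

from lerobot.envs.factory import make_env_config, make_env

cfg = make_env_config("gym-robotics", task="FetchPickAndPlace-v4")
envs = make_env(cfg)                 # {"gym-robotics": {0: <SyncVectorEnv>}}
vec = envs["gym-robotics"][0]
obs, info = vec.reset()
print(obs["state"].shape)            # batched: (1, state_dim)
vec.close()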

src/lerobot/envs/gym_robotics.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
import gymnasium_robotics
import gymnasium as gym
import numpy as np
from typing import Dict

# Alias the config class so it does not collide with the env adapter class
# of the same name defined below.
from lerobot.envs.configs import GymRoboticsEnv as GymRoboticsEnvConfig


def create_gym_robotics_envs(
    cfg: GymRoboticsEnvConfig,
    n_envs: int = 1,
    use_async_envs: bool = False,
) -> Dict[str, Dict[int, gym.vector.VectorEnv]]:
    """
    Build vectorized GymRoboticsEnv(s) from the GymRoboticsEnv config and return:
        { "<env_type>": { 0: <VectorEnv> } }
    Minimal and consistent with the return type expected by make_env(...).
    """
    # pull minimal fields from the config (with safe defaults)
    task = getattr(cfg, "task", "FetchPickAndPlace-v4")
    base_seed = getattr(cfg, "seed", 0)
    image_key = getattr(cfg, "image_key", "agentview_image")

    # per-worker factory functions
    def _mk_one(worker_idx: int):
        def _ctor():
            seed = None if base_seed is None else int(base_seed) + worker_idx
            return GymRoboticsEnv(task=task, seed=seed, image_key=image_key)

        return _ctor

    fns = [_mk_one(i) for i in range(n_envs)]
    vec_env = gym.vector.AsyncVectorEnv(fns) if use_async_envs else gym.vector.SyncVectorEnv(fns)

    # key name kept simple/flat; matches --env.type
    return {"gym-robotics": {0: vec_env}}


class GymRoboticsEnv(gym.Env):
    """Minimal adapter: wraps a Gymnasium-Robotics env and returns a LeRobot-style obs dict."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}

    def __init__(self, task: str, seed: int | None = 0, image_key: str = "agentview_image", **make_kwargs):
        gym.register_envs(gymnasium_robotics)
        make_kwargs = dict(make_kwargs or {})
        make_kwargs["render_mode"] = "rgb_array"
        self.env = gym.make(task, **make_kwargs)

        self._rng = np.random.default_rng(seed)
        self._seed = seed
        self._image_key = image_key

        # action space: forwarded from the underlying env
        self.action_space = self.env.action_space

        # --- infer observation space once (do a temp reset+render) ---
        tmp_obs, _ = self.env.reset(seed=int(self._rng.integers(0, 2**31 - 1)) if seed is not None else None)
        frame = self.env.render()
        obs = self._to_obs(tmp_obs, frame)

        # build observation_space to match obs
        def _box_like(x, low=-np.inf, high=np.inf, dtype=np.float32):
            x = np.asarray(x)
            return gym.spaces.Box(low=low, high=high, shape=x.shape, dtype=dtype)

        img = obs["images"][self._image_key]
        spaces = {
            "images": gym.spaces.Dict(
                {self._image_key: gym.spaces.Box(low=0, high=255, shape=img.shape, dtype=np.uint8)}
            ),
            "state": _box_like(obs["state"]),
            # aliases for libero-style preprocessors:
            "agent_pos": _box_like(obs["state"]),
            "pixels": gym.spaces.Box(low=0, high=255, shape=img.shape, dtype=np.uint8),
        }
        if "goal" in obs:
            spaces["goal"] = _box_like(obs["goal"])
        if "achieved_goal" in obs:
            spaces["achieved_goal"] = _box_like(obs["achieved_goal"])

        self.observation_space = gym.spaces.Dict(spaces)
        # leave env in a valid state; the vector wrapper will call reset() again later

        # passthrough spec (if present on the wrapped env)
        self.spec = getattr(self.env, "spec", None)

        # determine max episode steps for upstream code that reads _max_episode_steps
        max_steps = getattr(self.env, "_max_episode_steps", None)
        if max_steps is None and self.spec is not None:
            max_steps = getattr(self.spec, "max_episode_steps", None)

        # try unwrapping one level if wrapped
        if max_steps is None and hasattr(self.env, "env"):
            inner = self.env.env
            max_steps = getattr(inner, "_max_episode_steps", None)
            if max_steps is None:
                inner_spec = getattr(inner, "spec", None)
                if inner_spec is not None:
                    max_steps = getattr(inner_spec, "max_episode_steps", None)

        # final fallback
        if max_steps is None:
            max_steps = 1000  # sensible default; adjust if needed

        self._max_episode_steps = int(max_steps)

    def reset(self, seed: int | None = None, **kwargs):
        if seed is None and self._seed is not None:
            seed = int(self._rng.integers(0, 2**31 - 1))
        super().reset(seed=seed)
        tmp_obs, info = self.env.reset(seed=seed)
        frame = self.env.render()
        observation = self._to_obs(tmp_obs, frame)
        return observation, info

    def step(self, action):
        if isinstance(self.action_space, gym.spaces.Box):
            action = np.clip(np.asarray(action, dtype=np.float32), self.action_space.low, self.action_space.high)
        tmp_obs, reward, terminated, truncated, info = self.env.step(action)
        frame = self.env.render()
        obs_out = self._to_obs(tmp_obs, frame)
        return obs_out, float(reward), bool(terminated), bool(truncated), info

    def close(self):
        self.env.close()

    def render(self):
        """Return an RGB frame (HxWx3, uint8) as Gymnasium expects."""
        frame = self.env.render()  # underlying env created with render_mode="rgb_array"
        if frame is None:
            raise RuntimeError("render() returned None; ensure render_mode='rgb_array' in make().")
        return frame.astype(np.uint8, copy=False)

    # ---- helpers ----
    @staticmethod
    def _flat(x):
        if x is None:
            return np.zeros((0,), dtype=np.float32)
        return np.asarray(x, dtype=np.float32).reshape(-1)

    def _to_obs(self, obs, frame):
        if isinstance(obs, dict):
            state = self._flat(obs.get("observation"))
            desired = obs.get("desired_goal")
            achieved = obs.get("achieved_goal")
        else:
            state = self._flat(obs)
            desired = achieved = None

        rgb = frame.astype(np.uint8, copy=False)

        out = {
            # original gym keys
            "images": {self._image_key: rgb},
            "state": state,
            # aliases expected by LeRobot preprocessors
            "agent_pos": state,  # alias for state
            "pixels": rgb,       # alias for a single RGB view
        }
        if desired is not None:
            out["goal"] = self._flat(desired)
        if achieved is not None:
            out["achieved_goal"] = self._flat(achieved)
        return out
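
Taken together, a short vectorized rollout sketch (assumes the gym-robotics extra is installed; image height and width depend on the Fetch task's default render size):

import numpy as np
from lerobot.envs.configs import GymRoboticsEnv as GymRoboticsEnvConfig
from lerobot.envs.gym_robotics import create_gym_robotics_envs

cfg = GymRoboticsEnvConfig(task="FetchPickAndPlace-v4", seed=0)
vec = create_gym_robotics_envs(cfg, n_envs=2)["gym-robotics"][0]

obs, info = vec.reset()
print(obs["pixels"].shape)  # batched: (2, H, W, 3)
actions = np.zeros((2,) + vec.single_action_space.shape, dtype=np.float32)
obs, reward, terminated, truncated, info = vec.step(actions)
vec.close()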
