Commit 2123ce8

vmoens and cursoragent committed
[Feature] Add RayEvalWorker for async evaluation (#3474)
Add a new torchrl.eval module with RayEvalWorker, a Ray-backed async evaluation helper that runs environment + policy in a separate process. This is useful when the evaluation environment requires special process-level initialisation (e.g. Isaac Lab AppLauncher) or when evaluation should happen concurrently with training on a separate GPU.

API: submit(weights, max_steps) / poll() -> {reward, frames}.

Co-authored-by: Cursor <[email protected]>
ghstack-source-id: f3574c2
Pull-Request: #3474
Co-authored-by: Cursor <[email protected]>
1 parent bde59db commit 2123ce8
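
For orientation, the submit/poll API described in the commit message can be driven from a training loop roughly as follows. This is a minimal sketch, not part of the commit: it assumes Ray and Gymnasium are installed, reuses the Pendulum-v1 setup from the test below, and replaces the actual training step with a placeholder:

    import ray
    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    from torchrl.collectors.distributed import RayEvalWorker
    from torchrl.envs import GymEnv, StepCounter, TransformedEnv


    def make_env():
        # Small CPU-only eval env, capped at 100 steps per episode.
        return TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(100))


    def make_policy(env):
        # Linear observation -> action policy, mirroring the test case.
        obs_dim = env.observation_spec["observation"].shape[-1]
        action_dim = env.action_spec.shape[-1]
        return TensorDictModule(
            nn.Linear(obs_dim, action_dim),
            in_keys=["observation"],
            out_keys=["action"],
        )


    ray.init(ignore_reinit_error=True)
    policy = make_policy(make_env())  # the "training" policy in this process
    worker = RayEvalWorker(
        init_fn=None,          # no process-level setup needed for this env
        env_maker=make_env,
        policy_maker=make_policy,
        num_gpus=0,
    )

    try:
        for iteration in range(200):
            # ... one optimisation step on `policy` would go here ...

            if iteration % 50 == 0:
                # Ship a CPU copy of the current weights and start a rollout.
                weights = TensorDict.from_module(policy).data.detach().cpu()
                worker.submit(weights, max_steps=100)

            # Non-blocking: returns None while the rollout is still running.
            result = worker.poll()
            if result is not None:
                print(f"iter {iteration}: eval reward {result['reward']:.3f}")
    finally:
        worker.shutdown()
        ray.shutdown()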

File tree

test/test_libs.py
torchrl/collectors/distributed/__init__.py
torchrl/collectors/distributed/ray_eval_worker.py

3 files changed: +355 -0 lines changed

test/test_libs.py

Lines changed: 60 additions & 0 deletions
@@ -5822,6 +5822,66 @@ def test_isaaclab_lstm(self, env):
        assert ("next", "recurrent_state_c") in rollout.keys(True)


+@pytest.mark.skipif(
+    not _has_ray or not _has_gymnasium, reason="Ray or Gymnasium not found"
+)
+class TestRayEvalWorker:
+    """Tests for the RayEvalWorker async evaluation helper."""
+
+    @pytest.fixture(autouse=True)
+    def _setup_ray(self):
+        import ray
+
+        ray.init(ignore_reinit_error=True, num_gpus=0)
+        yield
+        ray.shutdown()
+
+    def test_ray_eval_worker_basic(self):
+        """Test submit/poll cycle with a simple environment."""
+        import torch.nn as nn
+
+        from tensordict import TensorDict
+        from tensordict.nn import TensorDictModule
+        from torchrl.collectors.distributed import RayEvalWorker
+
+        from torchrl.envs import GymEnv, StepCounter, TransformedEnv
+
+        def make_env():
+            return TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(10))
+
+        def make_policy(env):
+            action_dim = env.action_spec.shape[-1]
+            obs_dim = env.observation_spec["observation"].shape[-1]
+            return TensorDictModule(
+                nn.Linear(obs_dim, action_dim),
+                in_keys=["observation"],
+                out_keys=["action"],
+            )
+
+        worker = RayEvalWorker(
+            init_fn=None,
+            env_maker=make_env,
+            policy_maker=make_policy,
+            num_gpus=0,
+        )
+        try:
+            # Before submit, poll returns None
+            assert worker.poll() is None
+
+            weights = (
+                TensorDict.from_module(make_policy(make_env())).data.detach().cpu()
+            )
+            worker.submit(weights, max_steps=5)
+
+            # Wait for result (blocking poll)
+            result = worker.poll(timeout=30)
+            assert result is not None
+            assert "reward" in result
+            assert "frames" in result
+        finally:
+            worker.shutdown()
+
+
@pytest.mark.skipif(not _has_procgen, reason="Procgen not found")
class TestProcgen:
    @pytest.mark.parametrize("envname", ["coinrun", "starpilot"])

torchrl/collectors/distributed/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
    DistributedWeightUpdater,
)
from .ray import RayCollector
+from .ray_eval_worker import RayEvalWorker
from .rpc import RPCCollector, RPCDataCollector, RPCWeightUpdater
from .sync import DistributedSyncCollector, DistributedSyncDataCollector
from .utils import submitit_delayed_launcher
@@ -28,5 +29,6 @@
    "DistributedWeightUpdater",
    "RPCWeightUpdater",
    "RayCollector",
+    "RayEvalWorker",
    "submitit_delayed_launcher",
]

torchrl/collectors/distributed/ray_eval_worker.py

Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Ray-based asynchronous evaluation worker.
+
+This module provides :class:`RayEvalWorker`, a generic helper that runs an
+environment and policy inside a dedicated Ray actor process. This is useful
+when the evaluation environment requires special process-level initialisation
+(e.g. Isaac Lab's ``AppLauncher`` must run before ``import torch``) or when
+evaluation should happen concurrently with training on a separate GPU.
+
+Typical usage::
+
+    from torchrl.collectors.distributed import RayEvalWorker
+
+    worker = RayEvalWorker(
+        init_fn=my_init,           # called first in the actor process
+        env_maker=make_eval_env,   # returns a TorchRL env
+        policy_maker=make_policy,  # returns a TorchRL policy module
+        num_gpus=1,
+    )
+
+    # Non-blocking: submit weights and start a rollout
+    weights = TensorDict.from_module(policy).data.detach().cpu()
+    worker.submit(weights, max_steps=500)
+
+    # Later – check if the rollout finished
+    result = worker.poll()  # None while still running
+    if result is not None:
+        print(result["reward"])  # scalar mean episode reward
+        print(result["frames"])  # (1, T, C, H, W) uint8 tensor or None
+"""
+from __future__ import annotations
+
+import logging
+from collections.abc import Callable
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+class RayEvalWorker:
+    """Asynchronous evaluation worker backed by a Ray actor.
+
+    The worker creates a **new Python process** (via Ray) and inside it:
+
+    1. Calls *init_fn* -- use this for any process-level setup that must happen
+       before other imports (e.g. Isaac Lab ``AppLauncher``).
+    2. Creates the environment via *env_maker*.
+    3. Creates the policy via *policy_maker(env)*.
+
+    Thereafter, :meth:`submit` sends new policy weights and triggers an
+    evaluation rollout. :meth:`poll` returns the result (reward and optional
+    video frames) when the rollout finishes, or ``None`` if it is still
+    running.
+
+    Args:
+        init_fn: Optional callable invoked at the very start of the actor
+            process, before *env_maker* or *policy_maker*. All imports should
+            be **local** inside this callable so that the actor's fresh Python
+            process can control import order. Set to ``None`` to skip.
+        env_maker: Callable that returns a TorchRL environment. Called once
+            inside the actor after *init_fn*. If the underlying environment
+            supports ``render_mode="rgb_array"``, the actor will call
+            ``render()`` on each evaluation step and return the frames.
+        policy_maker: Callable ``(env) -> policy`` that builds the policy
+            module given the environment. Called once inside the actor after
+            the environment has been created.
+        num_gpus: Number of GPUs to request from Ray for this actor.
+            Defaults to 1.
+        reward_keys: Nested key(s) used to read the reward from the rollout
+            tensordict. Defaults to ``("next", "reward")``.
+        **remote_kwargs: Extra keyword arguments forwarded to
+            ``ray.remote()`` when creating the actor class (e.g.
+            ``num_cpus``, ``runtime_env``).
+    """
+
+    def __init__(
+        self,
+        init_fn: Callable[[], None] | None,
+        env_maker: Callable[[], Any],
+        policy_maker: Callable[[Any], Any],
+        *,
+        num_gpus: int = 1,
+        reward_keys: tuple[str, ...] = ("next", "reward"),
+        **remote_kwargs: Any,
+    ) -> None:
+        import ray
+
+        self._reward_keys = reward_keys
+
+        # Build the remote actor class dynamically so that the caller does not
+        # need to depend on Ray at import time.
+        actor_cls = ray.remote(num_gpus=num_gpus, **remote_kwargs)(_EvalActor)
+
+        self._actor = actor_cls.remote(init_fn, env_maker, policy_maker)
+        self._pending_ref: ray.ObjectRef | None = None
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+
+    def submit(
+        self,
+        weights: Any,
+        max_steps: int,
+        *,
+        deterministic: bool = True,
+        break_when_any_done: bool = True,
+    ) -> None:
+        """Start an asynchronous evaluation rollout.
+
+        If a previous rollout is still running, its result is silently
+        discarded (fire-and-forget semantics).
+
+        Args:
+            weights: Policy weights, typically obtained via
+                ``TensorDict.from_module(policy).data.detach().cpu()``.
+            max_steps: Maximum number of environment steps per rollout.
+            deterministic: If ``True``, use deterministic exploration.
+            break_when_any_done: If ``True``, stop the rollout as soon as
+                any sub-environment reports ``done``.
+        """
+        # Discard any previous un-polled result
+        self._pending_ref = self._actor.eval.remote(
+            weights,
+            max_steps,
+            self._reward_keys,
+            deterministic,
+            break_when_any_done,
+        )
+
+    def poll(self, timeout: float = 0) -> dict | None:
+        """Return the evaluation result if ready, otherwise ``None``.
+
+        The returned dict contains:
+
+        - ``"reward"`` -- scalar mean episode reward.
+        - ``"frames"`` -- ``(1, T, C, H, W)`` uint8 CPU tensor of rendered
+          frames, or ``None`` if the environment does not render.
+
+        Args:
+            timeout: Seconds to wait for the result. ``0`` means
+                non-blocking (return immediately if not ready).
+        """
+        if self._pending_ref is None:
+            return None
+
+        import ray
+
+        ready, _ = ray.wait([self._pending_ref], timeout=timeout)
+        if not ready:
+            return None
+
+        result = ray.get(self._pending_ref)
+        self._pending_ref = None
+        return result
+
+    def shutdown(self) -> None:
+        """Close the environment and kill the actor."""
+        import ray
+
+        try:
+            ray.get(self._actor.shutdown.remote())
+        except Exception:
+            logger.warning("RayEvalWorker: error during shutdown", exc_info=True)
+        ray.kill(self._actor)
+        self._actor = None
+        self._pending_ref = None
+
+
+# ======================================================================
+# Inner actor -- runs inside the Ray worker process
+# ======================================================================
+
+
+class _EvalActor:
+    """Plain class turned into a Ray actor by :class:`RayEvalWorker`.
+
+    All heavy imports happen inside methods so that the module-level import
+    of this file does **not** pull in torch, torchrl, or any simulator SDK.
+    """
+
+    def __init__(
+        self,
+        init_fn: Callable[[], None] | None,
+        env_maker: Callable[[], Any],
+        policy_maker: Callable[[Any], Any],
+    ) -> None:
+        # --- process-level initialisation (e.g. AppLauncher) ---
+        if init_fn is not None:
+            init_fn()
+
+        # --- now safe to import torch / torchrl ---
+        import torch  # noqa: F401
+
+        self.env = env_maker()
+        self.policy = policy_maker(self.env)
+        # Cache device before any to_module call can replace nn.Parameter
+        # with plain tensors (which makes .parameters() empty).
+        self._device = next(self.policy.parameters()).device
+
+    def eval(
+        self,
+        weights,
+        max_steps: int,
+        reward_keys: tuple[str, ...],
+        deterministic: bool,
+        break_when_any_done: bool,
+    ) -> dict:
+        """Run an evaluation rollout with the given weights."""
+        import torch
+
+        from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
+
+        # Load weights into the eval policy (move to policy device first)
+        weights.to(self._device).to_module(self.policy)
+
+        frames = []
+        total_reward = 0.0
+        num_steps = 0
+
+        exploration = (
+            ExplorationType.DETERMINISTIC if deterministic else ExplorationType.RANDOM
+        )
+        with set_exploration_type(exploration), torch.no_grad():
+            td = self.env.reset()
+            for _i in range(max_steps):
+                td = self.policy(td)
+                td = self.env.step(td)
+
+                total_reward += td[reward_keys].mean().item()
+                num_steps += 1
+
+                frame = self._try_render()
+                if frame is not None:
+                    frames.append(frame)
+
+                done = td.get(("next", "done"), None)
+                if break_when_any_done and done is not None and done.any():
+                    break
+
+                td = step_mdp(td)
+
+        mean_reward = total_reward / max(1, num_steps)
+
+        # Format video: (1, T, C, H, W) uint8 CPU tensor
+        video = None
+        if frames:
+            video = torch.stack(frames, dim=0).unsqueeze(0).cpu()
+
+        return {"reward": mean_reward, "frames": video}
+
+    def _try_render(self):
+        """Render one frame from the underlying environment.
+
+        Walks the wrapper chain to find a callable ``render()`` method
+        and returns the result as a ``(C, H, W)`` uint8 tensor, or
+        ``None`` if rendering is unavailable.
+        """
+        import numpy as np
+        import torch
+
+        # Walk through TransformedEnv / wrapper chain to the base env.
+        env = self.env
+        while hasattr(env, "base_env"):
+            env = env.base_env
+        render_fn = getattr(env, "render", None)
+        # If the base env delegates to a gymnasium env, prefer that.
+        if hasattr(env, "_env") and hasattr(env._env, "render"):
+            render_fn = env._env.render
+        if render_fn is None:
+            return None
+
+        raw = render_fn()
+        if raw is None:
+            return None
+
+        if isinstance(raw, np.ndarray):
+            raw = torch.from_numpy(raw.copy())
+
+        # (H, W, C) -> (C, H, W)
+        if raw.ndim == 3 and raw.shape[-1] in (3, 4):
+            raw = raw[..., :3]
+            raw = raw.permute(2, 0, 1)
+
+        return raw.to(torch.uint8)
+
+    def shutdown(self) -> None:
+        """Shut down the environment."""
+        if hasattr(self, "env") and not self.env.is_closed:
+            self.env.close()
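
A closing note on the init_fn hook documented above: it exists so that process-level setup can run inside the Ray actor before torch or any simulator is imported, which is why the docstring asks for local imports in the maker callables. A minimal sketch of that pattern, again not part of the commit; the environment variable and the Pendulum-v1 environment are only illustrative:

    import os

    from torchrl.collectors.distributed import RayEvalWorker


    def my_init():
        # Runs first, inside the fresh actor process, before env_maker and
        # policy_maker. Only process-level setup here; no heavy imports.
        os.environ["SDL_VIDEODRIVER"] = "dummy"  # illustrative: headless pygame rendering


    def make_eval_env():
        # Imported lazily so the environment stack is only loaded inside the
        # actor, after my_init has run.
        from torchrl.envs import GymEnv, StepCounter, TransformedEnv

        return TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(200))


    def make_policy(env):
        import torch.nn as nn
        from tensordict.nn import TensorDictModule

        return TensorDictModule(
            nn.Linear(
                env.observation_spec["observation"].shape[-1],
                env.action_spec.shape[-1],
            ),
            in_keys=["observation"],
            out_keys=["action"],
        )


    worker = RayEvalWorker(
        init_fn=my_init,
        env_maker=make_eval_env,
        policy_maker=make_policy,
        num_gpus=0,  # request a GPU here to evaluate on a separate device
    )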
