# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Ray-based asynchronous evaluation worker.

This module provides :class:`RayEvalWorker`, a generic helper that runs an
environment and policy inside a dedicated Ray actor process. This is useful
when the evaluation environment requires special process-level initialisation
(e.g. Isaac Lab's ``AppLauncher`` must run before ``import torch``) or when
evaluation should happen concurrently with training on a separate GPU.

Typical usage::

    from torchrl.collectors.distributed import RayEvalWorker

    worker = RayEvalWorker(
        init_fn=my_init,           # called first in the actor process
        env_maker=make_eval_env,   # returns a TorchRL env
        policy_maker=make_policy,  # returns a TorchRL policy module
        num_gpus=1,
    )

    # Non-blocking: submit weights and start a rollout
    weights = TensorDict.from_module(policy).data.detach().cpu()
    worker.submit(weights, max_steps=500)

    # Later – check if the rollout finished
    result = worker.poll()  # None while still running
    if result is not None:
        print(result["reward"])  # mean per-step reward over the rollout
        print(result["frames"])  # (1, T, C, H, W) uint8 tensor or None
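
    # When evaluation is no longer needed, free the actor (and its GPU)
    worker.shutdown()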
"""
from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Any

logger = logging.getLogger(__name__)


class RayEvalWorker:
    """Asynchronous evaluation worker backed by a Ray actor.

    The worker creates a **new Python process** (via Ray) and inside it:

    1. Calls *init_fn* -- use this for any process-level setup that must happen
       before other imports (e.g. Isaac Lab ``AppLauncher``).
    2. Creates the environment via *env_maker*.
    3. Creates the policy via *policy_maker(env)*.

    Thereafter, :meth:`submit` sends new policy weights and triggers an
    evaluation rollout. :meth:`poll` returns the result (reward and optional
    video frames) when the rollout finishes, or ``None`` if it is still
    running.

    Args:
        init_fn: Optional callable invoked at the very start of the actor
            process, before *env_maker* or *policy_maker*. All imports should
            be **local** inside this callable so that the actor's fresh Python
            process can control import order. Set to ``None`` to skip.
        env_maker: Callable that returns a TorchRL environment. Called once
            inside the actor after *init_fn*. If the underlying environment
            supports ``render_mode="rgb_array"``, the actor will call
            ``render()`` on each evaluation step and return the frames.
        policy_maker: Callable ``(env) -> policy`` that builds the policy
            module given the environment. Called once inside the actor after
            the environment has been created.
        num_gpus: Number of GPUs to request from Ray for this actor.
            Defaults to 1.
        reward_keys: Nested key (a tuple of strings) under which the reward
            is stored in the rollout tensordict. Defaults to
            ``("next", "reward")``.
        **remote_kwargs: Extra keyword arguments forwarded to
            ``ray.remote()`` when creating the actor class (e.g.
            ``num_cpus``, ``runtime_env``).
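
    Example (a minimal sketch of a training loop with non-blocking
    evaluation; ``make_eval_env``, ``make_policy``, ``policy`` and
    ``eval_every`` are placeholders defined by the caller)::

        from tensordict import TensorDict

        worker = RayEvalWorker(None, make_eval_env, make_policy, num_gpus=1)
        for step in range(10_000):
            ...  # one training update on ``policy``
            if step % eval_every == 0:
                weights = TensorDict.from_module(policy).data.detach().cpu()
                worker.submit(weights, max_steps=500)
            result = worker.poll()
            if result is not None:
                print(step, result["reward"])
        worker.shutdown()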
    """

    def __init__(
        self,
        init_fn: Callable[[], None] | None,
        env_maker: Callable[[], Any],
        policy_maker: Callable[[Any], Any],
        *,
        num_gpus: int = 1,
        reward_keys: tuple[str, ...] = ("next", "reward"),
        **remote_kwargs: Any,
    ) -> None:
        import ray

        self._reward_keys = reward_keys

        # Build the remote actor class dynamically so that the caller does not
        # need to depend on Ray at import time.
        actor_cls = ray.remote(num_gpus=num_gpus, **remote_kwargs)(_EvalActor)

        self._actor = actor_cls.remote(init_fn, env_maker, policy_maker)
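        # Only one in-flight rollout is tracked at a time; ``submit`` simply
        # overwrites this reference.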
        self._pending_ref: ray.ObjectRef | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def submit(
        self,
        weights: Any,
        max_steps: int,
        *,
        deterministic: bool = True,
        break_when_any_done: bool = True,
    ) -> None:
        """Start an asynchronous evaluation rollout.

        If a previous rollout is still running, its result is silently
        discarded (fire-and-forget semantics).

        Args:
            weights: Policy weights, typically obtained via
                ``TensorDict.from_module(policy).data.detach().cpu()``.
            max_steps: Maximum number of environment steps per rollout.
            deterministic: If ``True``, use deterministic exploration.
            break_when_any_done: If ``True``, stop the rollout as soon as
                any sub-environment reports ``done``.
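
        Example (a minimal sketch; ``policy`` is the training policy in the
        main process and ``worker`` this :class:`RayEvalWorker` instance)::

            from tensordict import TensorDict

            weights = TensorDict.from_module(policy).data.detach().cpu()
            worker.submit(weights, max_steps=500, deterministic=True)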
        """
        # Discard any previous un-polled result
        self._pending_ref = self._actor.eval.remote(
            weights,
            max_steps,
            self._reward_keys,
            deterministic,
            break_when_any_done,
        )

    def poll(self, timeout: float = 0) -> dict | None:
        """Return the evaluation result if ready, otherwise ``None``.

        The returned dict contains:

        - ``"reward"`` -- mean per-step reward over the rollout (averaged
          over the batch and over steps).
        - ``"frames"`` -- ``(1, T, C, H, W)`` uint8 CPU tensor of rendered
          frames, or ``None`` if the environment does not render.

        Args:
            timeout: Seconds to wait for the result. ``0`` means
                non-blocking (return immediately if not ready).
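
        Example (a sketch of the typical non-blocking check inside a
        training loop; ``worker`` is a :class:`RayEvalWorker` with a
        previously submitted rollout)::

            result = worker.poll()
            if result is not None:
                print("eval reward:", result["reward"])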
        """
        if self._pending_ref is None:
            return None

        import ray

        ready, _ = ray.wait([self._pending_ref], timeout=timeout)
        if not ready:
            return None

        result = ray.get(self._pending_ref)
        self._pending_ref = None
        return result

    def shutdown(self) -> None:
        """Close the environment and kill the actor."""
        import ray

        try:
            ray.get(self._actor.shutdown.remote())
        except Exception:
            logger.warning("RayEvalWorker: error during shutdown", exc_info=True)
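        # Kill the actor even if the graceful shutdown failed so that its
        # process (and any reserved GPU) is released.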
        ray.kill(self._actor)
        self._actor = None
        self._pending_ref = None


# ======================================================================
# Inner actor -- runs inside the Ray worker process
# ======================================================================


class _EvalActor:
    """Plain class turned into a Ray actor by :class:`RayEvalWorker`.

    All heavy imports happen inside methods so that the module-level import
    of this file does **not** pull in torch, torchrl, or any simulator SDK.
    """

    def __init__(
        self,
        init_fn: Callable[[], None] | None,
        env_maker: Callable[[], Any],
        policy_maker: Callable[[Any], Any],
    ) -> None:
        # --- process-level initialisation (e.g. AppLauncher) ---
        if init_fn is not None:
            init_fn()

        # --- now safe to import torch / torchrl ---
        import torch  # noqa: F401

        self.env = env_maker()
        self.policy = policy_maker(self.env)
        # Cache device before any to_module call can replace nn.Parameter
        # with plain tensors (which makes .parameters() empty).
        self._device = next(self.policy.parameters()).device

    def eval(
        self,
        weights,
        max_steps: int,
        reward_keys: tuple[str, ...],
        deterministic: bool,
        break_when_any_done: bool,
    ) -> dict:
        """Run an evaluation rollout with the given weights."""
        import torch

        from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp

        # Load weights into the eval policy (move to policy device first)
        weights.to(self._device).to_module(self.policy)

        frames = []
        total_reward = 0.0
        num_steps = 0

        exploration = (
            ExplorationType.DETERMINISTIC if deterministic else ExplorationType.RANDOM
        )
        with set_exploration_type(exploration), torch.no_grad():
            td = self.env.reset()
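            # Manual step loop (rather than ``env.rollout``) so that a frame
            # can be rendered and the reward accumulated after every step.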
            for _i in range(max_steps):
                td = self.policy(td)
                td = self.env.step(td)

                total_reward += td[reward_keys].mean().item()
                num_steps += 1

                frame = self._try_render()
                if frame is not None:
                    frames.append(frame)

                done = td.get(("next", "done"), None)
                if break_when_any_done and done is not None and done.any():
                    break

                td = step_mdp(td)

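        # Average over steps: this is the mean per-step reward of the
        # rollout, not the cumulative episode return.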
        mean_reward = total_reward / max(1, num_steps)

        # Format video: (1, T, C, H, W) uint8 CPU tensor
        video = None
        if frames:
            video = torch.stack(frames, dim=0).unsqueeze(0).cpu()

        return {"reward": mean_reward, "frames": video}

    def _try_render(self):
        """Render one frame from the underlying environment.

        Walks the wrapper chain to find a callable ``render()`` method
        and returns the result as a ``(C, H, W)`` uint8 tensor, or
        ``None`` if rendering is unavailable.
        """
        import numpy as np
        import torch

        # Walk through TransformedEnv / wrapper chain to the base env.
        env = self.env
        while hasattr(env, "base_env"):
            env = env.base_env
        render_fn = getattr(env, "render", None)
        # If the base env delegates to a gymnasium env, prefer that.
        if hasattr(env, "_env") and hasattr(env._env, "render"):
            render_fn = env._env.render
        if render_fn is None:
            return None

        raw = render_fn()
        if raw is None:
            return None

        if isinstance(raw, np.ndarray):
            raw = torch.from_numpy(raw.copy())

        # (H, W, C) -> (C, H, W)
        if raw.ndim == 3 and raw.shape[-1] in (3, 4):
            raw = raw[..., :3]
            raw = raw.permute(2, 0, 1)

        return raw.to(torch.uint8)

    def shutdown(self) -> None:
        """Shut down the environment."""
        if hasattr(self, "env") and not self.env.is_closed:
            self.env.close()