Skip to content

Commit 2c0f3ee

Browse files
committed
Update
[ghstack-poisoned]
1 parent 16db465 commit 2c0f3ee

File tree

3 files changed

+226
-211
lines changed

3 files changed

+226
-211
lines changed

examples/collectors/async_batched_collector.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
through an :class:`~torchrl.modules.InferenceServer`.
66
77
Architecture:
8-
- Each environment runs in its own worker (thread or process).
8+
- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel
9+
using the chosen backend (``"threading"`` or ``"multiprocessing"``).
10+
- One lightweight coordinator thread per environment owns a slot in the pool
11+
and an inference client.
912
- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
1013
and runs a single forward pass.
11-
- Workers submit observations directly to the server and block until the
12-
action is ready. There is no global synchronisation barrier -- fast envs
13-
keep stepping while slow ones wait for inference.
14+
- There is no global synchronisation barrier -- fast envs keep stepping
15+
while slow ones wait for inference.
1416
1517
The user only supplies:
1618
- A list of environment factories

torchrl/collectors/_async_batched.py

Lines changed: 68 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
import multiprocessing as mp
87
import queue
98
import threading
109
from collections import deque, OrderedDict
@@ -16,77 +15,41 @@
1615

1716
from torchrl._utils import logger as torchrl_logger
1817
from torchrl.collectors._base import BaseCollector
19-
from torchrl.data.utils import CloudpickleWrapper
20-
from torchrl.envs import EnvBase
18+
from torchrl.envs import AsyncEnvPool, EnvBase
2119
from torchrl.modules.inference_server import InferenceServer, ThreadingTransport
22-
from torchrl.modules.inference_server._mp import MPTransport
2320
from torchrl.modules.inference_server._transport import InferenceTransport
2421

2522
_ENV_IDX_KEY = "env_index"
2623

2724

28-
def _threading_env_loop(
29-
env_factory: Callable,
30-
create_env_kwargs: dict,
25+
def _env_loop(
26+
pool: AsyncEnvPool,
27+
env_id: int,
3128
transport: InferenceTransport,
3229
result_queue: queue.Queue,
3330
shutdown_event: threading.Event,
34-
env_id: int,
3531
):
36-
"""Per-env worker thread that submits directly to the InferenceServer.
32+
"""Per-env worker thread using pool slot for env execution and InferenceServer for policy.
3733
38-
Each worker owns one environment and one inference client. The
39-
client blocks until the server has batched and processed the
40-
observation, so the worker loop is simply:
34+
Each thread owns one slot in the :class:`~torchrl.envs.AsyncEnvPool` and
35+
one inference client. The pool handles the actual environment execution in
36+
whatever backend it was configured with (threading, multiprocessing, etc.),
37+
while this thread coordinates the send/recv cycle and inference submission.
4138
42-
reset -> infer (blocking) -> step -> put transition -> infer -> ...
39+
reset -> infer (blocking) -> step_send -> step_recv -> put transition -> infer -> ...
4340
"""
44-
env = env_factory(**create_env_kwargs)
4541
client = transport.client()
4642

4743
try:
48-
obs = env.reset()
44+
pool.async_reset_send(env_index=env_id)
45+
obs = pool.async_reset_recv(env_index=env_id)
4946
action_td = client(obs)
5047

5148
while not shutdown_event.is_set():
52-
cur_td, next_obs = env.step_and_maybe_reset(action_td)
53-
cur_td.set(_ENV_IDX_KEY, env_id)
54-
result_queue.put(cur_td)
55-
if shutdown_event.is_set():
56-
break
57-
action_td = client(next_obs)
58-
except Exception:
59-
if not shutdown_event.is_set():
60-
raise
61-
finally:
62-
env.close()
63-
64-
65-
def _mp_env_loop(
66-
env_factory: Callable,
67-
create_env_kwargs: dict,
68-
client,
69-
result_queue,
70-
shutdown_event,
71-
env_id: int,
72-
):
73-
"""Per-env worker process that submits directly to the InferenceServer.
74-
75-
Identical to :func:`_threading_env_loop` but designed for
76-
:class:`multiprocessing.Process` workers. The ``client`` is a
77-
pre-created :class:`_MPInferenceClient` whose underlying
78-
``mp.Queue`` handles are inherited by the child process.
79-
"""
80-
if isinstance(env_factory, CloudpickleWrapper):
81-
env_factory = env_factory.fn
82-
env = env_factory(**create_env_kwargs)
83-
84-
try:
85-
obs = env.reset()
86-
action_td = client(obs)
87-
88-
while not shutdown_event.is_set():
89-
cur_td, next_obs = env.step_and_maybe_reset(action_td)
49+
pool.async_step_and_maybe_reset_send(action_td, env_index=env_id)
50+
cur_td, next_obs = pool.async_step_and_maybe_reset_recv(
51+
env_index=env_id
52+
)
9053
cur_td.set(_ENV_IDX_KEY, env_id)
9154
result_queue.put(cur_td)
9255
if shutdown_event.is_set():
@@ -95,23 +58,25 @@ def _mp_env_loop(
9558
except Exception:
9659
if not shutdown_event.is_set():
9760
raise
98-
finally:
99-
env.close()
10061

10162

10263
class AsyncBatchedCollector(BaseCollector):
103-
"""Asynchronous collector that pairs per-env workers with an :class:`~torchrl.modules.InferenceServer`.
64+
"""Asynchronous collector that pairs per-env threads with an :class:`~torchrl.envs.AsyncEnvPool` and an :class:`~torchrl.modules.InferenceServer`.
10465
10566
Unlike :class:`~torchrl.collectors.Collector`, this collector fully
10667
decouples environment stepping from policy inference:
10768
108-
* Each environment runs in its own worker (thread or process) and
109-
submits observations directly to the inference server.
110-
* An :class:`~torchrl.modules.InferenceServer` running in a background
69+
* An :class:`~torchrl.envs.AsyncEnvPool` runs *N* environments using
70+
whatever backend the user chooses (``"threading"``,
71+
``"multiprocessing"``).
72+
* *N* lightweight coordinator threads -- one per environment -- each own
73+
a slot in the pool and an inference client. A thread sends its env's
74+
observation to the :class:`~torchrl.modules.InferenceServer`, blocks
75+
until the batched action is returned, then sends the action back to
76+
the pool for stepping.
77+
* The :class:`~torchrl.modules.InferenceServer` running in a background
11178
thread continuously drains observation submissions, batches them, runs
11279
a single forward pass, and fans actions back out.
113-
* Workers block on a ``Future`` while waiting for inference, releasing
114-
the GIL so other workers and the server can proceed.
11580
11681
There is **no global synchronisation barrier**: fast environments keep
11782
stepping while slow ones wait for inference, and the server always
@@ -142,18 +107,19 @@ class AsyncBatchedCollector(BaseCollector):
142107
server_timeout (float, optional): seconds the server waits for work
143108
before dispatching a partial batch. Defaults to ``0.01``.
144109
transport (InferenceTransport, optional): a pre-built transport
145-
backend. When ``None`` (default) one is created automatically
146-
to match the ``backend`` (``ThreadingTransport`` for
147-
``"threading"``, ``MPTransport`` for ``"multiprocessing"``).
148-
Pass a :class:`~torchrl.modules.RayTransport` or
110+
backend. When ``None`` (default) a
111+
:class:`~torchrl.modules.ThreadingTransport` is created
112+
automatically (since worker threads always live in the main
113+
process). Pass a :class:`~torchrl.modules.RayTransport` or
149114
:class:`~torchrl.modules.MonarchTransport` for distributed
150-
setups (workers will be spawned as threads that hold
151-
Ray/Monarch clients).
115+
setups where the inference server is remote.
152116
device (torch.device or str, optional): device for policy inference.
153117
Passed to the inference server. Defaults to ``None``.
154-
backend (str, optional): how to run per-env workers. One of
155-
``"threading"`` or ``"multiprocessing"``. Defaults to
156-
``"threading"``.
118+
backend (str, optional): backend for the
119+
:class:`~torchrl.envs.AsyncEnvPool` that runs environments. One
120+
of ``"threading"`` or ``"multiprocessing"``. The coordinator
121+
threads are always Python threads regardless of this setting.
122+
Defaults to ``"threading"``.
157123
reset_at_each_iter (bool, optional): whether to reset all envs at the
158124
start of every collection batch. Defaults to ``False``.
159125
postproc (Callable, optional): post-processing transform applied to
@@ -235,9 +201,7 @@ def __init__(
235201

236202
# ---- build transport --------------------------------------------------
237203
if transport is None:
238-
transport = (
239-
MPTransport() if backend == "multiprocessing" else ThreadingTransport()
240-
)
204+
transport = ThreadingTransport()
241205
self._transport = transport
242206

243207
# ---- build inference server -------------------------------------------
@@ -264,9 +228,10 @@ def __init__(
264228
self._iter = -1
265229

266230
# ---- runtime state (created lazily) -----------------------------------
267-
self._shutdown_event: threading.Event | mp.Event = None
268-
self._result_queue: queue.Queue | mp.Queue = None
269-
self._workers: list = []
231+
self._shutdown_event: threading.Event | None = None
232+
self._result_queue: queue.Queue | None = None
233+
self._env_pool: AsyncEnvPool | None = None
234+
self._workers: list[threading.Thread] = []
270235

271236
# Per-env trajectory accumulators (for yield_completed_trajectories)
272237
self._yield_queues: list[deque] = [deque() for _ in range(self._num_envs)]
@@ -276,82 +241,51 @@ def __init__(
276241
# Lifecycle
277242
# ------------------------------------------------------------------
278243

279-
def _normalise_env_kwargs(self) -> list[dict]:
280-
env_kwargs = self._create_env_kwargs
281-
if env_kwargs is None:
282-
return [{}] * self._num_envs
283-
if isinstance(env_kwargs, dict):
284-
return [env_kwargs] * self._num_envs
285-
return list(env_kwargs)
286-
287244
def _ensure_started(self) -> None:
288-
"""Start the inference server and spawn per-env workers."""
289-
if self._workers and all(
290-
(w.is_alive() if hasattr(w, "is_alive") else True) for w in self._workers
291-
):
245+
"""Create the env pool, start the server and per-env threads."""
246+
if self._workers and all(w.is_alive() for w in self._workers):
292247
return
293248

249+
# Build env pool
250+
kwargs = {}
251+
if self._create_env_kwargs is not None:
252+
kwargs["create_env_kwargs"] = self._create_env_kwargs
253+
self._env_pool = AsyncEnvPool(
254+
self._create_env_fn,
255+
backend=self._backend,
256+
**kwargs,
257+
)
258+
259+
# Start inference server
294260
if not self._server.is_alive:
295261
self._server.start()
296262

297-
env_kwargs = self._normalise_env_kwargs()
298-
299-
if self._backend == "multiprocessing":
300-
self._start_mp_workers(env_kwargs)
301-
else:
302-
self._start_threading_workers(env_kwargs)
303-
304-
def _start_threading_workers(self, env_kwargs: list[dict]) -> None:
263+
# Start per-env coordinator threads
305264
self._result_queue = queue.Queue()
306265
self._shutdown_event = threading.Event()
307266

308267
self._workers = []
309268
for i in range(self._num_envs):
310269
t = threading.Thread(
311-
target=_threading_env_loop,
270+
target=_env_loop,
312271
kwargs={
313-
"env_factory": self._create_env_fn[i],
314-
"create_env_kwargs": env_kwargs[i],
272+
"pool": self._env_pool,
273+
"env_id": i,
315274
"transport": self._transport,
316275
"result_queue": self._result_queue,
317276
"shutdown_event": self._shutdown_event,
318-
"env_id": i,
319277
},
320278
daemon=True,
321279
name=f"AsyncBatchedCollector-env-{i}",
322280
)
323281
self._workers.append(t)
324282
t.start()
325283

326-
def _start_mp_workers(self, env_kwargs: list[dict]) -> None:
327-
ctx = mp.get_context("spawn")
328-
self._result_queue = ctx.Queue()
329-
self._shutdown_event = ctx.Event()
330-
331-
# Pre-create one client per env before spawning (queues are inherited)
332-
clients = [self._transport.client() for _ in range(self._num_envs)]
333-
334-
self._workers = []
335-
for i in range(self._num_envs):
336-
env_fn = self._create_env_fn[i]
337-
if not isinstance(env_fn, EnvBase) and env_fn.__class__.__name__ != "EnvCreator":
338-
env_fn = CloudpickleWrapper(env_fn)
339-
340-
p = ctx.Process(
341-
target=_mp_env_loop,
342-
kwargs={
343-
"env_factory": env_fn,
344-
"create_env_kwargs": env_kwargs[i],
345-
"client": clients[i],
346-
"result_queue": self._result_queue,
347-
"shutdown_event": self._shutdown_event,
348-
"env_id": i,
349-
},
350-
daemon=True,
351-
name=f"AsyncBatchedCollector-env-{i}",
352-
)
353-
self._workers.append(p)
354-
p.start()
284+
@property
285+
def env(self) -> AsyncEnvPool:
286+
"""The underlying :class:`AsyncEnvPool`."""
287+
self._ensure_started()
288+
return self._env_pool
355289

356290
@property
357291
def policy(self) -> Callable:
@@ -434,21 +368,20 @@ def shutdown(
434368
close_env: bool = True,
435369
raise_on_error: bool = True,
436370
) -> None:
437-
"""Shut down the collector, inference server and workers."""
371+
"""Shut down the collector, inference server, threads and env pool."""
438372
if self._shutdown_event is not None:
439373
self._shutdown_event.set()
440374
_timeout = timeout or 5.0
441375
for w in self._workers:
442376
w.join(timeout=_timeout)
443-
# Terminate any stragglers (multiprocessing only)
444-
for w in self._workers:
445-
if hasattr(w, "terminate") and w.is_alive():
446-
w.terminate()
447377
self._workers = []
448378
self._server.shutdown(timeout=_timeout)
379+
if close_env and self._env_pool is not None:
380+
self._env_pool.close(raise_if_closed=raise_on_error)
381+
self._env_pool = None
449382

450383
def set_seed(self, seed: int, static_seed: bool = False) -> int:
451-
"""Set the seed (no-op; envs are created inside workers)."""
384+
"""Set the seed (no-op; envs are created inside the pool)."""
452385
return seed
453386

454387
def state_dict(self) -> OrderedDict:

0 commit comments

Comments
 (0)