44# LICENSE file in the root directory of this source tree.
55from __future__ import annotations
66
7- import multiprocessing as mp
87import queue
98import threading
109from collections import deque , OrderedDict
1615
1716from torchrl ._utils import logger as torchrl_logger
1817from torchrl .collectors ._base import BaseCollector
19- from torchrl .data .utils import CloudpickleWrapper
20- from torchrl .envs import EnvBase
18+ from torchrl .envs import AsyncEnvPool , EnvBase
2119from torchrl .modules .inference_server import InferenceServer , ThreadingTransport
22- from torchrl .modules .inference_server ._mp import MPTransport
2320from torchrl .modules .inference_server ._transport import InferenceTransport
2421
2522_ENV_IDX_KEY = "env_index"
2623
2724
28- def _threading_env_loop (
29- env_factory : Callable ,
30- create_env_kwargs : dict ,
25+ def _env_loop (
26+ pool : AsyncEnvPool ,
27+ env_id : int ,
3128 transport : InferenceTransport ,
3229 result_queue : queue .Queue ,
3330 shutdown_event : threading .Event ,
34- env_id : int ,
3531):
36- """Per-env worker thread that submits directly to the InferenceServer.
32+ """Per-env worker thread using pool slot for env execution and InferenceServer for policy .
3733
38- Each worker owns one environment and one inference client. The
39- client blocks until the server has batched and processed the
40- observation, so the worker loop is simply:
34+ Each thread owns one slot in the :class:`~torchrl.envs.AsyncEnvPool` and
35+ one inference client. The pool handles the actual environment execution in
36+ whatever backend it was configured with (threading, multiprocessing, etc.),
37+ while this thread coordinates the send/recv cycle and inference submission.
4138
42- reset -> infer (blocking) -> step -> put transition -> infer -> ...
39+ reset -> infer (blocking) -> step_send -> step_recv -> put transition -> infer -> ...
4340 """
44- env = env_factory (** create_env_kwargs )
4541 client = transport .client ()
4642
4743 try :
48- obs = env .reset ()
44+ pool .async_reset_send (env_index = env_id )
45+ obs = pool .async_reset_recv (env_index = env_id )
4946 action_td = client (obs )
5047
5148 while not shutdown_event .is_set ():
52- cur_td , next_obs = env .step_and_maybe_reset (action_td )
53- cur_td .set (_ENV_IDX_KEY , env_id )
54- result_queue .put (cur_td )
55- if shutdown_event .is_set ():
56- break
57- action_td = client (next_obs )
58- except Exception :
59- if not shutdown_event .is_set ():
60- raise
61- finally :
62- env .close ()
63-
64-
65- def _mp_env_loop (
66- env_factory : Callable ,
67- create_env_kwargs : dict ,
68- client ,
69- result_queue ,
70- shutdown_event ,
71- env_id : int ,
72- ):
73- """Per-env worker process that submits directly to the InferenceServer.
74-
75- Identical to :func:`_threading_env_loop` but designed for
76- :class:`multiprocessing.Process` workers. The ``client`` is a
77- pre-created :class:`_MPInferenceClient` whose underlying
78- ``mp.Queue`` handles are inherited by the child process.
79- """
80- if isinstance (env_factory , CloudpickleWrapper ):
81- env_factory = env_factory .fn
82- env = env_factory (** create_env_kwargs )
83-
84- try :
85- obs = env .reset ()
86- action_td = client (obs )
87-
88- while not shutdown_event .is_set ():
89- cur_td , next_obs = env .step_and_maybe_reset (action_td )
49+ pool .async_step_and_maybe_reset_send (action_td , env_index = env_id )
50+ cur_td , next_obs = pool .async_step_and_maybe_reset_recv (
51+ env_index = env_id
52+ )
9053 cur_td .set (_ENV_IDX_KEY , env_id )
9154 result_queue .put (cur_td )
9255 if shutdown_event .is_set ():
@@ -95,23 +58,25 @@ def _mp_env_loop(
9558 except Exception :
9659 if not shutdown_event .is_set ():
9760 raise
98- finally :
99- env .close ()
10061
10162
10263class AsyncBatchedCollector (BaseCollector ):
103- """Asynchronous collector that pairs per-env workers with an :class:`~torchrl.modules.InferenceServer`.
64+ """Asynchronous collector that pairs per-env threads with an :class:`~torchrl.envs.AsyncEnvPool` and an :class:`~torchrl.modules.InferenceServer`.
10465
10566 Unlike :class:`~torchrl.collectors.Collector`, this collector fully
10667 decouples environment stepping from policy inference:
10768
108- * Each environment runs in its own worker (thread or process) and
109- submits observations directly to the inference server.
110- * An :class:`~torchrl.modules.InferenceServer` running in a background
69+ * An :class:`~torchrl.envs.AsyncEnvPool` runs *N* environments using
70+ whatever backend the user chooses (``"threading"``,
71+ ``"multiprocessing"``).
72+ * *N* lightweight coordinator threads -- one per environment -- each own
73+ a slot in the pool and an inference client. A thread sends its env's
74+ observation to the :class:`~torchrl.modules.InferenceServer`, blocks
75+ until the batched action is returned, then sends the action back to
76+ the pool for stepping.
77+ * The :class:`~torchrl.modules.InferenceServer` running in a background
11178 thread continuously drains observation submissions, batches them, runs
11279 a single forward pass, and fans actions back out.
113- * Workers block on a ``Future`` while waiting for inference, releasing
114- the GIL so other workers and the server can proceed.
11580
11681 There is **no global synchronisation barrier**: fast environments keep
11782 stepping while slow ones wait for inference, and the server always
@@ -142,18 +107,19 @@ class AsyncBatchedCollector(BaseCollector):
142107 server_timeout (float, optional): seconds the server waits for work
143108 before dispatching a partial batch. Defaults to ``0.01``.
144109 transport (InferenceTransport, optional): a pre-built transport
145- backend. When ``None`` (default) one is created automatically
146- to match the ``backend`` (`` ThreadingTransport`` for
147- ``"threading"``, ``MPTransport`` for ``"multiprocessing"``).
148- Pass a :class:`~torchrl.modules.RayTransport` or
110+ backend. When ``None`` (default) a
111+ :class:`~torchrl.modules. ThreadingTransport` is created
112+ automatically (since worker threads always live in the main
113+ process). Pass a :class:`~torchrl.modules.RayTransport` or
149114 :class:`~torchrl.modules.MonarchTransport` for distributed
150- setups (workers will be spawned as threads that hold
151- Ray/Monarch clients).
115+ setups where the inference server is remote.
152116 device (torch.device or str, optional): device for policy inference.
153117 Passed to the inference server. Defaults to ``None``.
154- backend (str, optional): how to run per-env workers. One of
155- ``"threading"`` or ``"multiprocessing"``. Defaults to
156- ``"threading"``.
118+ backend (str, optional): backend for the
119+ :class:`~torchrl.envs.AsyncEnvPool` that runs environments. One
120+ of ``"threading"`` or ``"multiprocessing"``. The coordinator
121+ threads are always Python threads regardless of this setting.
122+ Defaults to ``"threading"``.
157123 reset_at_each_iter (bool, optional): whether to reset all envs at the
158124 start of every collection batch. Defaults to ``False``.
159125 postproc (Callable, optional): post-processing transform applied to
@@ -235,9 +201,7 @@ def __init__(
235201
236202 # ---- build transport --------------------------------------------------
237203 if transport is None :
238- transport = (
239- MPTransport () if backend == "multiprocessing" else ThreadingTransport ()
240- )
204+ transport = ThreadingTransport ()
241205 self ._transport = transport
242206
243207 # ---- build inference server -------------------------------------------
@@ -264,9 +228,10 @@ def __init__(
264228 self ._iter = - 1
265229
266230 # ---- runtime state (created lazily) -----------------------------------
267- self ._shutdown_event : threading .Event | mp .Event = None
268- self ._result_queue : queue .Queue | mp .Queue = None
269- self ._workers : list = []
231+ self ._shutdown_event : threading .Event | None = None
232+ self ._result_queue : queue .Queue | None = None
233+ self ._env_pool : AsyncEnvPool | None = None
234+ self ._workers : list [threading .Thread ] = []
270235
271236 # Per-env trajectory accumulators (for yield_completed_trajectories)
272237 self ._yield_queues : list [deque ] = [deque () for _ in range (self ._num_envs )]
@@ -276,82 +241,51 @@ def __init__(
276241 # Lifecycle
277242 # ------------------------------------------------------------------
278243
279- def _normalise_env_kwargs (self ) -> list [dict ]:
280- env_kwargs = self ._create_env_kwargs
281- if env_kwargs is None :
282- return [{}] * self ._num_envs
283- if isinstance (env_kwargs , dict ):
284- return [env_kwargs ] * self ._num_envs
285- return list (env_kwargs )
286-
287244 def _ensure_started (self ) -> None :
288- """Start the inference server and spawn per-env workers."""
289- if self ._workers and all (
290- (w .is_alive () if hasattr (w , "is_alive" ) else True ) for w in self ._workers
291- ):
245+ """Create the env pool, start the server and per-env threads."""
246+ if self ._workers and all (w .is_alive () for w in self ._workers ):
292247 return
293248
249+ # Build env pool
250+ kwargs = {}
251+ if self ._create_env_kwargs is not None :
252+ kwargs ["create_env_kwargs" ] = self ._create_env_kwargs
253+ self ._env_pool = AsyncEnvPool (
254+ self ._create_env_fn ,
255+ backend = self ._backend ,
256+ ** kwargs ,
257+ )
258+
259+ # Start inference server
294260 if not self ._server .is_alive :
295261 self ._server .start ()
296262
297- env_kwargs = self ._normalise_env_kwargs ()
298-
299- if self ._backend == "multiprocessing" :
300- self ._start_mp_workers (env_kwargs )
301- else :
302- self ._start_threading_workers (env_kwargs )
303-
304- def _start_threading_workers (self , env_kwargs : list [dict ]) -> None :
263+ # Start per-env coordinator threads
305264 self ._result_queue = queue .Queue ()
306265 self ._shutdown_event = threading .Event ()
307266
308267 self ._workers = []
309268 for i in range (self ._num_envs ):
310269 t = threading .Thread (
311- target = _threading_env_loop ,
270+ target = _env_loop ,
312271 kwargs = {
313- "env_factory " : self ._create_env_fn [ i ] ,
314- "create_env_kwargs " : env_kwargs [ i ] ,
272+ "pool " : self ._env_pool ,
273+ "env_id " : i ,
315274 "transport" : self ._transport ,
316275 "result_queue" : self ._result_queue ,
317276 "shutdown_event" : self ._shutdown_event ,
318- "env_id" : i ,
319277 },
320278 daemon = True ,
321279 name = f"AsyncBatchedCollector-env-{ i } " ,
322280 )
323281 self ._workers .append (t )
324282 t .start ()
325283
326- def _start_mp_workers (self , env_kwargs : list [dict ]) -> None :
327- ctx = mp .get_context ("spawn" )
328- self ._result_queue = ctx .Queue ()
329- self ._shutdown_event = ctx .Event ()
330-
331- # Pre-create one client per env before spawning (queues are inherited)
332- clients = [self ._transport .client () for _ in range (self ._num_envs )]
333-
334- self ._workers = []
335- for i in range (self ._num_envs ):
336- env_fn = self ._create_env_fn [i ]
337- if not isinstance (env_fn , EnvBase ) and env_fn .__class__ .__name__ != "EnvCreator" :
338- env_fn = CloudpickleWrapper (env_fn )
339-
340- p = ctx .Process (
341- target = _mp_env_loop ,
342- kwargs = {
343- "env_factory" : env_fn ,
344- "create_env_kwargs" : env_kwargs [i ],
345- "client" : clients [i ],
346- "result_queue" : self ._result_queue ,
347- "shutdown_event" : self ._shutdown_event ,
348- "env_id" : i ,
349- },
350- daemon = True ,
351- name = f"AsyncBatchedCollector-env-{ i } " ,
352- )
353- self ._workers .append (p )
354- p .start ()
284+ @property
285+ def env (self ) -> AsyncEnvPool :
286+ """The underlying :class:`AsyncEnvPool`."""
287+ self ._ensure_started ()
288+ return self ._env_pool
355289
356290 @property
357291 def policy (self ) -> Callable :
@@ -434,21 +368,20 @@ def shutdown(
434368 close_env : bool = True ,
435369 raise_on_error : bool = True ,
436370 ) -> None :
437- """Shut down the collector, inference server and workers ."""
371+ """Shut down the collector, inference server, threads and env pool ."""
438372 if self ._shutdown_event is not None :
439373 self ._shutdown_event .set ()
440374 _timeout = timeout or 5.0
441375 for w in self ._workers :
442376 w .join (timeout = _timeout )
443- # Terminate any stragglers (multiprocessing only)
444- for w in self ._workers :
445- if hasattr (w , "terminate" ) and w .is_alive ():
446- w .terminate ()
447377 self ._workers = []
448378 self ._server .shutdown (timeout = _timeout )
379+ if close_env and self ._env_pool is not None :
380+ self ._env_pool .close (raise_if_closed = raise_on_error )
381+ self ._env_pool = None
449382
450383 def set_seed (self , seed : int , static_seed : bool = False ) -> int :
451- """Set the seed (no-op; envs are created inside workers )."""
384+ """Set the seed (no-op; envs are created inside the pool )."""
452385 return seed
453386
454387 def state_dict (self ) -> OrderedDict :
0 commit comments