2121
# TensorDict key under which the env-worker loop records the index of the
# environment (within the AsyncEnvPool) that produced each transition.
_ENV_IDX_KEY = "env_index"
2323
24+ _POLICY_BACKENDS = ("threading" , "multiprocessing" , "ray" , "monarch" )
25+ _ENV_BACKENDS = ("threading" , "multiprocessing" )
26+
27+
28+ def _make_transport (
29+ policy_backend : str , num_slots : int | None = None
30+ ) -> InferenceTransport :
31+ """Create an :class:`InferenceTransport` from a backend name.
32+
33+ Args:
34+ policy_backend: one of ``"threading"``, ``"multiprocessing"``,
35+ ``"ray"``, or ``"monarch"``.
36+ num_slots: when set and ``policy_backend="threading"``, a
37+ :class:`~torchrl.modules.SlotTransport` is created instead of
38+ the generic :class:`~torchrl.modules.ThreadingTransport`.
39+ """
40+ if policy_backend == "threading" :
41+ if num_slots is not None :
42+ from torchrl .modules .inference_server ._slot import SlotTransport
43+
44+ return SlotTransport (num_slots )
45+ return ThreadingTransport ()
46+ if policy_backend == "multiprocessing" :
47+ from torchrl .modules .inference_server ._mp import MPTransport
48+
49+ return MPTransport ()
50+ if policy_backend == "ray" :
51+ from torchrl .modules .inference_server ._ray import RayTransport
52+
53+ return RayTransport ()
54+ if policy_backend == "monarch" :
55+ from torchrl .modules .inference_server ._monarch import MonarchTransport
56+
57+ return MonarchTransport ()
58+ raise ValueError (
59+ f"Unknown policy_backend { policy_backend !r} . "
60+ f"Expected one of { _POLICY_BACKENDS } ."
61+ )
62+
2463
2564def _env_loop (
2665 pool : AsyncEnvPool ,
@@ -47,9 +86,7 @@ def _env_loop(
4786
4887 while not shutdown_event .is_set ():
4988 pool .async_step_and_maybe_reset_send (action_td , env_index = env_id )
50- cur_td , next_obs = pool .async_step_and_maybe_reset_recv (
51- env_index = env_id
52- )
89+ cur_td , next_obs = pool .async_step_and_maybe_reset_recv (env_index = env_id )
5390 cur_td .set (_ENV_IDX_KEY , env_id )
5491 result_queue .put (cur_td )
5592 if shutdown_event .is_set ():
@@ -104,22 +141,35 @@ class AsyncBatchedCollector(BaseCollector):
104141 max_batch_size (int, optional): upper bound on the number of
105142 requests the inference server processes in a single forward pass.
106143 Defaults to ``64``.
144+ min_batch_size (int, optional): minimum number of requests the
145+ inference server accumulates before dispatching a batch. After
146+ the first request arrives the server keeps draining for up to
147+ ``server_timeout`` seconds until this many items are collected.
148+ ``1`` (default) dispatches immediately.
107149 server_timeout (float, optional): seconds the server waits for work
108150 before dispatching a partial batch. Defaults to ``0.01``.
109151 transport (InferenceTransport, optional): a pre-built transport
110- backend. When ``None`` (default) a
111- :class:`~torchrl.modules.ThreadingTransport` is created
112- automatically (since worker threads always live in the main
113- process). Pass a :class:`~torchrl.modules.RayTransport` or
114- :class:`~torchrl.modules.MonarchTransport` for distributed
115- setups where the inference server is remote.
152+ object. When provided, it takes precedence over
153+ ``policy_backend``. When ``None`` (default) a transport is
154+ created automatically from the resolved ``policy_backend``.
116155 device (torch.device or str, optional): device for policy inference.
117156 Passed to the inference server. Defaults to ``None``.
118- backend (str, optional): backend for the
157+ backend (str, optional): global default backend for both
158+ environments and policy inference. Specific overrides
159+ ``env_backend`` and ``policy_backend`` take precedence when set.
160+ One of ``"threading"``, ``"multiprocessing"``, ``"ray"``, or
161+ ``"monarch"``. Defaults to ``"threading"``.
162+ env_backend (str, optional): backend for the
119163 :class:`~torchrl.envs.AsyncEnvPool` that runs environments. One
120- of ``"threading"`` or ``"multiprocessing"``. The coordinator
121- threads are always Python threads regardless of this setting.
122- Defaults to ``"threading"``.
164+ of ``"threading"`` or ``"multiprocessing"``. Falls back to
165+ ``backend`` when ``None``. The coordinator threads are always
166+ Python threads regardless of this setting. Defaults to ``None``.
167+ policy_backend (str, optional): backend for the inference transport
168+ used to communicate with the
169+ :class:`~torchrl.modules.InferenceServer`. One of
170+ ``"threading"``, ``"multiprocessing"``, ``"ray"``, or
171+ ``"monarch"``. Falls back to ``backend`` when ``None``.
172+ Defaults to ``None``.
123173 reset_at_each_iter (bool, optional): whether to reset all envs at the
124174 start of every collection batch. Defaults to ``False``.
125175 postproc (Callable, optional): post-processing transform applied to
@@ -169,10 +219,16 @@ def __init__(
169219 frames_per_batch : int ,
170220 total_frames : int = - 1 ,
171221 max_batch_size : int = 64 ,
222+ min_batch_size : int = 1 ,
172223 server_timeout : float = 0.01 ,
173224 transport : InferenceTransport | None = None ,
174225 device : torch .device | str | None = None ,
175- backend : Literal ["threading" , "multiprocessing" ] = "threading" ,
226+ backend : Literal [
227+ "threading" , "multiprocessing" , "ray" , "monarch"
228+ ] = "threading" ,
229+ env_backend : Literal ["threading" , "multiprocessing" ] | None = None ,
230+ policy_backend : Literal ["threading" , "multiprocessing" , "ray" , "monarch" ]
231+ | None = None ,
176232 reset_at_each_iter : bool = False ,
177233 postproc : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
178234 yield_completed_trajectories : bool = False ,
@@ -196,19 +252,34 @@ def __init__(
196252 raise TypeError ("create_env_fn must be a list of env factories." )
197253 self ._create_env_fn = list (create_env_fn )
198254 self ._num_envs = len (create_env_fn )
199- self ._backend = backend
200255 self ._create_env_kwargs = create_env_kwargs
201256
257+ # ---- resolve backends -------------------------------------------------
258+ effective_env_backend = env_backend if env_backend is not None else backend
259+ effective_policy_backend = (
260+ policy_backend if policy_backend is not None else backend
261+ )
262+ if effective_env_backend not in _ENV_BACKENDS :
263+ raise ValueError (
264+ f"env_backend={ effective_env_backend !r} is not supported. "
265+ f"Expected one of { _ENV_BACKENDS } ."
266+ )
267+ self ._env_backend = effective_env_backend
268+ self ._policy_backend = effective_policy_backend
269+
202270 # ---- build transport --------------------------------------------------
203271 if transport is None :
204- transport = ThreadingTransport ()
272+ transport = _make_transport (
273+ effective_policy_backend , num_slots = self ._num_envs
274+ )
205275 self ._transport = transport
206276
207277 # ---- build inference server -------------------------------------------
208278 self ._server = InferenceServer (
209279 model = policy ,
210280 transport = transport ,
211281 max_batch_size = max_batch_size ,
282+ min_batch_size = min_batch_size ,
212283 timeout = server_timeout ,
213284 device = device ,
214285 weight_sync = weight_sync ,
@@ -252,7 +323,7 @@ def _ensure_started(self) -> None:
252323 kwargs ["create_env_kwargs" ] = self ._create_env_kwargs
253324 self ._env_pool = AsyncEnvPool (
254325 self ._create_env_fn ,
255- backend = self ._backend ,
326+ backend = self ._env_backend ,
256327 ** kwargs ,
257328 )
258329
@@ -303,9 +374,18 @@ def _rollout_frames(self) -> TensorDictBase:
303374 transitions : list [TensorDictBase ] = []
304375
305376 while collected < self .frames_per_batch :
377+ # Block for at least one transition
306378 td = rq .get ()
307379 transitions .append (td )
308380 collected += td .numel ()
381+ # Batch-drain any additional items already in the queue
382+ while collected < self .frames_per_batch :
383+ try :
384+ td = rq .get_nowait ()
385+ except queue .Empty :
386+ break
387+ transitions .append (td )
388+ collected += td .numel ()
309389 if self .verbose :
310390 torchrl_logger .debug (
311391 f"AsyncBatchedCollector: { collected } /{ self .frames_per_batch } frames"
0 commit comments