Skip to content

Commit 50f9bf4

Browse files
committed
Update
[ghstack-poisoned]
1 parent 122bc89 commit 50f9bf4

File tree

8 files changed: +335 additions, -43 deletions

benchmarks/bench_collectors.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
2. Collector (ParallelEnv x N) -- single-process, N envs in sub-procs
1313
3. MultiCollector (sync, x N) -- N sub-processes, sync delivery
1414
4. MultiCollector (async, x N) -- N sub-processes, async delivery
15-
5. AsyncBatchedCollector (threading) -- AsyncEnvPool + InferenceServer
15+
5. AsyncBatched (env=thread, pol=thread) -- threading pool + threading transport
16+
6. AsyncBatched (env=mp, pol=thread) -- multiprocessing pool + threading transport
1617
"""
1718
from __future__ import annotations
1819

@@ -368,33 +369,33 @@ def policy_factory():
368369
)
369370
)
370371

371-
# 5. AsyncBatchedCollector (threading backend)
372+
# 5. AsyncBatchedCollector (env=threading, policy=threading)
372373
results.append(
373374
bench(
374-
f"AsyncBatchedCollector threading (x{num_envs})",
375+
f"AsyncBatched env=thread pol=thread (x{num_envs})",
375376
lambda: AsyncBatchedCollector(
376377
create_env_fn=[make_env_fn] * num_envs,
377378
policy=policy_factory(),
378379
frames_per_batch=frames_per_batch,
379380
total_frames=-1,
380381
max_batch_size=num_envs,
381-
backend="threading",
382+
env_backend="threading",
382383
),
383384
target_frames=total_frames,
384385
)
385386
)
386387

387-
# 6. AsyncBatchedCollector (multiprocessing backend)
388+
# 6. AsyncBatchedCollector (env=multiprocessing, policy=threading)
388389
results.append(
389390
bench(
390-
f"AsyncBatchedCollector mp (x{num_envs})",
391+
f"AsyncBatched env=mp pol=thread (x{num_envs})",
391392
lambda: AsyncBatchedCollector(
392393
create_env_fn=[make_env_fn] * num_envs,
393394
policy=policy_factory(),
394395
frames_per_batch=frames_per_batch,
395396
total_frames=-1,
396397
max_batch_size=num_envs,
397-
backend="multiprocessing",
398+
env_backend="multiprocessing",
398399
),
399400
target_frames=total_frames,
400401
)

examples/collectors/async_batched_collector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,21 @@
66
77
Architecture:
88
- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel
9-
using the chosen backend (``"threading"`` or ``"multiprocessing"``).
9+
using the chosen ``env_backend`` (``"threading"`` or ``"multiprocessing"``).
1010
- One lightweight coordinator thread per environment owns a slot in the pool
1111
and an inference client.
1212
- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
13-
and runs a single forward pass.
13+
and runs a single forward pass. The communication layer (transport) is
14+
controlled by ``policy_backend`` (``"threading"``, ``"multiprocessing"``,
15+
``"ray"``, or ``"monarch"``).
1416
- There is no global synchronisation barrier -- fast envs keep stepping
1517
while slow ones wait for inference.
1618
19+
Backend parameters:
20+
- ``backend`` -- global default for both env pool and policy transport.
21+
- ``env_backend`` -- override for the env pool (falls back to ``backend``); only ``"threading"`` and ``"multiprocessing"`` are accepted.
22+
- ``policy_backend`` -- override for the transport (falls back to ``backend``).
23+
1724
The user only supplies:
1825
- A list of environment factories
1926
- A policy (or policy factory)

test/test_inference_server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ def test_basic_collection(self):
733733
frames_per_batch=frames_per_batch,
734734
total_frames=total_frames,
735735
max_batch_size=num_envs,
736-
backend="threading",
736+
env_backend="threading",
737737
)
738738
total_collected = 0
739739
for batch in collector:
@@ -751,7 +751,7 @@ def test_policy_factory(self):
751751
frames_per_batch=10,
752752
total_frames=20,
753753
max_batch_size=num_envs,
754-
backend="threading",
754+
env_backend="threading",
755755
)
756756
total_collected = 0
757757
for batch in collector:
@@ -791,7 +791,7 @@ def test_yield_completed_trajectories(self):
791791
total_frames=30,
792792
yield_completed_trajectories=True,
793793
max_batch_size=num_envs,
794-
backend="threading",
794+
env_backend="threading",
795795
)
796796
count = 0
797797
for batch in collector:
@@ -809,7 +809,7 @@ def test_shutdown_idempotent(self):
809809
policy=policy,
810810
frames_per_batch=10,
811811
total_frames=10,
812-
backend="threading",
812+
env_backend="threading",
813813
)
814814
# Consume one batch to start
815815
for _batch in collector:
@@ -825,7 +825,7 @@ def test_endless_collector(self):
825825
policy=policy,
826826
frames_per_batch=10,
827827
total_frames=-1,
828-
backend="threading",
828+
env_backend="threading",
829829
)
830830
collected = 0
831831
for batch in collector:
@@ -862,7 +862,7 @@ def postproc(td):
862862
frames_per_batch=10,
863863
total_frames=20,
864864
postproc=postproc,
865-
backend="threading",
865+
env_backend="threading",
866866
)
867867
for _ in collector:
868868
pass

torchrl/collectors/_async_batched.py

Lines changed: 97 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@
2121

2222
_ENV_IDX_KEY = "env_index"
2323

24+
_POLICY_BACKENDS = ("threading", "multiprocessing", "ray", "monarch")
25+
_ENV_BACKENDS = ("threading", "multiprocessing")
26+
27+
28+
def _make_transport(
29+
policy_backend: str, num_slots: int | None = None
30+
) -> InferenceTransport:
31+
"""Create an :class:`InferenceTransport` from a backend name.
32+
33+
Args:
34+
policy_backend: one of ``"threading"``, ``"multiprocessing"``,
35+
``"ray"``, or ``"monarch"``.
36+
num_slots: when set and ``policy_backend="threading"``, a
37+
:class:`~torchrl.modules.SlotTransport` is created instead of
38+
the generic :class:`~torchrl.modules.ThreadingTransport`.
39+
"""
40+
if policy_backend == "threading":
41+
if num_slots is not None:
42+
from torchrl.modules.inference_server._slot import SlotTransport
43+
44+
return SlotTransport(num_slots)
45+
return ThreadingTransport()
46+
if policy_backend == "multiprocessing":
47+
from torchrl.modules.inference_server._mp import MPTransport
48+
49+
return MPTransport()
50+
if policy_backend == "ray":
51+
from torchrl.modules.inference_server._ray import RayTransport
52+
53+
return RayTransport()
54+
if policy_backend == "monarch":
55+
from torchrl.modules.inference_server._monarch import MonarchTransport
56+
57+
return MonarchTransport()
58+
raise ValueError(
59+
f"Unknown policy_backend {policy_backend!r}. "
60+
f"Expected one of {_POLICY_BACKENDS}."
61+
)
62+
2463

2564
def _env_loop(
2665
pool: AsyncEnvPool,
@@ -47,9 +86,7 @@ def _env_loop(
4786

4887
while not shutdown_event.is_set():
4988
pool.async_step_and_maybe_reset_send(action_td, env_index=env_id)
50-
cur_td, next_obs = pool.async_step_and_maybe_reset_recv(
51-
env_index=env_id
52-
)
89+
cur_td, next_obs = pool.async_step_and_maybe_reset_recv(env_index=env_id)
5390
cur_td.set(_ENV_IDX_KEY, env_id)
5491
result_queue.put(cur_td)
5592
if shutdown_event.is_set():
@@ -104,22 +141,35 @@ class AsyncBatchedCollector(BaseCollector):
104141
max_batch_size (int, optional): upper bound on the number of
105142
requests the inference server processes in a single forward pass.
106143
Defaults to ``64``.
144+
min_batch_size (int, optional): minimum number of requests the
145+
inference server accumulates before dispatching a batch. After
146+
the first request arrives the server keeps draining for up to
147+
``server_timeout`` seconds until this many items are collected.
148+
``1`` (default) dispatches immediately.
107149
server_timeout (float, optional): seconds the server waits for work
108150
before dispatching a partial batch. Defaults to ``0.01``.
109151
transport (InferenceTransport, optional): a pre-built transport
110-
backend. When ``None`` (default) a
111-
:class:`~torchrl.modules.ThreadingTransport` is created
112-
automatically (since worker threads always live in the main
113-
process). Pass a :class:`~torchrl.modules.RayTransport` or
114-
:class:`~torchrl.modules.MonarchTransport` for distributed
115-
setups where the inference server is remote.
152+
object. When provided, it takes precedence over
153+
``policy_backend``. When ``None`` (default) a transport is
154+
created automatically from the resolved ``policy_backend``.
116155
device (torch.device or str, optional): device for policy inference.
117156
Passed to the inference server. Defaults to ``None``.
118-
backend (str, optional): backend for the
157+
backend (str, optional): global default backend for both
158+
environments and policy inference. Specific overrides
159+
``env_backend`` and ``policy_backend`` take precedence when set.
160+
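One of ``"threading"``, ``"multiprocessing"``, ``"ray"``, or
161+
``"monarch"``. Defaults to ``"threading"``. Environments only support ``"threading"`` and ``"multiprocessing"``, so when this is ``"ray"`` or ``"monarch"`` an explicit ``env_backend`` must be passed, otherwise a ``ValueError`` is raised.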
162+
env_backend (str, optional): backend for the
119163
:class:`~torchrl.envs.AsyncEnvPool` that runs environments. One
120-
of ``"threading"`` or ``"multiprocessing"``. The coordinator
121-
threads are always Python threads regardless of this setting.
122-
Defaults to ``"threading"``.
164+
of ``"threading"`` or ``"multiprocessing"``. Falls back to
165+
``backend`` when ``None``. The coordinator threads are always
166+
Python threads regardless of this setting. Defaults to ``None``.
167+
policy_backend (str, optional): backend for the inference transport
168+
used to communicate with the
169+
:class:`~torchrl.modules.InferenceServer`. One of
170+
``"threading"``, ``"multiprocessing"``, ``"ray"``, or
171+
``"monarch"``. Falls back to ``backend`` when ``None``.
172+
Defaults to ``None``.
123173
reset_at_each_iter (bool, optional): whether to reset all envs at the
124174
start of every collection batch. Defaults to ``False``.
125175
postproc (Callable, optional): post-processing transform applied to
@@ -169,10 +219,16 @@ def __init__(
169219
frames_per_batch: int,
170220
total_frames: int = -1,
171221
max_batch_size: int = 64,
222+
min_batch_size: int = 1,
172223
server_timeout: float = 0.01,
173224
transport: InferenceTransport | None = None,
174225
device: torch.device | str | None = None,
175-
backend: Literal["threading", "multiprocessing"] = "threading",
226+
backend: Literal[
227+
"threading", "multiprocessing", "ray", "monarch"
228+
] = "threading",
229+
env_backend: Literal["threading", "multiprocessing"] | None = None,
230+
policy_backend: Literal["threading", "multiprocessing", "ray", "monarch"]
231+
| None = None,
176232
reset_at_each_iter: bool = False,
177233
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
178234
yield_completed_trajectories: bool = False,
@@ -196,19 +252,34 @@ def __init__(
196252
raise TypeError("create_env_fn must be a list of env factories.")
197253
self._create_env_fn = list(create_env_fn)
198254
self._num_envs = len(create_env_fn)
199-
self._backend = backend
200255
self._create_env_kwargs = create_env_kwargs
201256

257+
# ---- resolve backends -------------------------------------------------
258+
effective_env_backend = env_backend if env_backend is not None else backend
259+
effective_policy_backend = (
260+
policy_backend if policy_backend is not None else backend
261+
)
262+
if effective_env_backend not in _ENV_BACKENDS:
263+
raise ValueError(
264+
f"env_backend={effective_env_backend!r} is not supported. "
265+
f"Expected one of {_ENV_BACKENDS}."
266+
)
267+
self._env_backend = effective_env_backend
268+
self._policy_backend = effective_policy_backend
269+
202270
# ---- build transport --------------------------------------------------
203271
if transport is None:
204-
transport = ThreadingTransport()
272+
transport = _make_transport(
273+
effective_policy_backend, num_slots=self._num_envs
274+
)
205275
self._transport = transport
206276

207277
# ---- build inference server -------------------------------------------
208278
self._server = InferenceServer(
209279
model=policy,
210280
transport=transport,
211281
max_batch_size=max_batch_size,
282+
min_batch_size=min_batch_size,
212283
timeout=server_timeout,
213284
device=device,
214285
weight_sync=weight_sync,
@@ -252,7 +323,7 @@ def _ensure_started(self) -> None:
252323
kwargs["create_env_kwargs"] = self._create_env_kwargs
253324
self._env_pool = AsyncEnvPool(
254325
self._create_env_fn,
255-
backend=self._backend,
326+
backend=self._env_backend,
256327
**kwargs,
257328
)
258329

@@ -303,9 +374,18 @@ def _rollout_frames(self) -> TensorDictBase:
303374
transitions: list[TensorDictBase] = []
304375

305376
while collected < self.frames_per_batch:
377+
# Block for at least one transition
306378
td = rq.get()
307379
transitions.append(td)
308380
collected += td.numel()
381+
# Batch-drain any additional items already in the queue
382+
while collected < self.frames_per_batch:
383+
try:
384+
td = rq.get_nowait()
385+
except queue.Empty:
386+
break
387+
transitions.append(td)
388+
collected += td.numel()
309389
if self.verbose:
310390
torchrl_logger.debug(
311391
f"AsyncBatchedCollector: {collected}/{self.frames_per_batch} frames"

torchrl/envs/async_envs.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -673,9 +673,7 @@ def async_step_and_maybe_reset_send(
673673
for _env_idx, local_td in _zip_strict(env_idx, local_tds):
674674
if not _per_env:
675675
self._current_step_reset = self._current_step_reset + 1
676-
self.input_queue[_env_idx].put(
677-
("step_and_maybe_reset", local_td, _per_env)
678-
)
676+
self.input_queue[_env_idx].put(("step_and_maybe_reset", local_td, _per_env))
679677

680678
def async_step_and_maybe_reset_recv(
681679
self, min_get: int = 1, env_index: int | None = None
@@ -807,29 +805,29 @@ def _env_exec(
807805
elif msg == "batch_size":
808806
output_queue.put(env.batch_size)
809807
elif msg == "reset":
810-
data = env.reset(data.copy())
808+
# No .copy() needed: data was deserialized from the queue
809+
# and is not referenced after this call.
810+
data = env.reset(data)
811811
data.set(cls._env_idx_key, NonTensorData(i))
812812
target = per_env_reset_queue if per_env else reset_queue
813813
target.put(data)
814814
elif msg == "_reset":
815-
data = env._reset(data.copy())
815+
data = env._reset(data)
816816
data.set(cls._env_idx_key, NonTensorData(i))
817817
reset_queue.put(data)
818818
elif msg == "step_and_maybe_reset":
819-
data, data_ = env.step_and_maybe_reset(data.copy())
819+
data, data_ = env.step_and_maybe_reset(data)
820820
data.set(cls._env_idx_key, NonTensorData(i))
821821
data_.set(cls._env_idx_key, NonTensorData(i))
822-
target = (
823-
per_env_step_reset_queue if per_env else step_reset_queue
824-
)
822+
target = per_env_step_reset_queue if per_env else step_reset_queue
825823
target.put((data, data_))
826824
elif msg == "step":
827-
data = env.step(data.copy())
825+
data = env.step(data)
828826
data.set(cls._env_idx_key, NonTensorData(i))
829827
target = per_env_step_queue if per_env else step_queue
830828
target.put(data)
831829
elif msg == "_step":
832-
data = env._step(data.copy())
830+
data = env._step(data)
833831
data.set(cls._env_idx_key, NonTensorData(i))
834832
step_queue.put(data)
835833
elif msg == "shutdown":

torchrl/modules/inference_server/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torchrl.modules.inference_server._mp import MPTransport
88
from torchrl.modules.inference_server._ray import RayTransport
99
from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
10+
from torchrl.modules.inference_server._slot import SlotTransport
1011
from torchrl.modules.inference_server._threading import ThreadingTransport
1112
from torchrl.modules.inference_server._transport import InferenceTransport
1213

@@ -17,5 +18,6 @@
1718
"MonarchTransport",
1819
"MPTransport",
1920
"RayTransport",
21+
"SlotTransport",
2022
"ThreadingTransport",
2123
]

0 commit comments

Comments
 (0)