Commit 6e415a5

vmoens and cursoragent committed
[Feature] AsyncBatchedCollector: async envs + auto-batching inference
Add AsyncBatchedCollector that pairs AsyncEnvPool with InferenceServer for pipelined RL data collection. Users supply env factories and a policy; the collector handles all internal wiring (transport, server, env pool).

Co-authored-by: Cursor <cursoragent@cursor.com>
ghstack-source-id: 3938b33
Pull-Request: #3498
1 parent 7e68007 commit 6e415a5

File tree

7 files changed: +721 / -1 lines changed

docs/source/reference/collectors.rst

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ making it easy to collect high-quality training data efficiently.

TorchRL provides several collector implementations optimized for different scenarios:

- :class:`Collector`: Single-process collection on the training worker
- :class:`AsyncBatchedCollector`: Async environments + auto-batching inference server
- :class:`MultiCollector`: Parallel collection across multiple workers (see below)
- **Distributed collectors**: For multi-node setups using Ray, RPC, or distributed backends (see :class:`DistributedCollector` / :class:`RPCCollector`)

docs/source/reference/collectors_single.rst

Lines changed: 44 additions & 0 deletions
@@ -15,6 +15,7 @@ Single node data collectors

    BaseCollector
    Collector
    AsyncCollector
    AsyncBatchedCollector
    MultiCollector
    MultiSyncCollector
    MultiAsyncCollector

@@ -29,6 +30,49 @@ Single node data collectors

- ``MultiSyncDataCollector`` → ``MultiSyncCollector``
- ``MultiaSyncDataCollector`` → ``MultiAsyncCollector``

Using AsyncBatchedCollector
---------------------------

The :class:`AsyncBatchedCollector` pairs an :class:`~torchrl.envs.AsyncEnvPool`
with an :class:`~torchrl.modules.InferenceServer` to pipeline environment
stepping and batched GPU inference. You only need to supply **env factories**
and a **policy** -- all internal wiring is handled automatically:

.. code-block:: python

    from torchrl.collectors import AsyncBatchedCollector
    from torchrl.envs import GymEnv
    from tensordict.nn import TensorDictModule
    import torch.nn as nn

    policy = TensorDictModule(
        nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)),
        in_keys=["observation"],
        out_keys=["action"],
    )

    collector = AsyncBatchedCollector(
        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
        policy=policy,
        frames_per_batch=200,
        total_frames=10_000,
        max_batch_size=8,
    )

    for data in collector:
        # data is a lazy-stacked TensorDict of collected transitions
        pass

    collector.shutdown()

**Key advantages over** :class:`Collector`:

- The inference server automatically **batches policy forward passes** from
  all environments, maximising GPU utilisation.
- Environment stepping and inference run in **overlapping fashion**, reducing
  idle time.
- Supports ``yield_completed_trajectories=True`` for episode-level yields.
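The episode-level yielding in the last bullet can be pictured with a small, self-contained sketch (plain Python; this is an illustration of the idea, not TorchRL's internal implementation): transitions arrive interleaved from many async envs, each env buffers its own steps, and a trajectory is emitted as soon as that env reports ``done``.

```python
from collections import defaultdict


def yield_trajectories(transitions):
    """Buffer per-env transitions; emit a full episode when done=True.

    `transitions` is an iterable of (env_id, step_data, done) tuples,
    as produced by interleaved asynchronous env stepping.
    """
    buffers = defaultdict(list)
    for env_id, step_data, done in transitions:
        buffers[env_id].append(step_data)
        if done:
            yield buffers.pop(env_id)  # one complete episode for this env


# Interleaved steps from two envs; env 1 finishes after 2 steps, env 0 after 3.
stream = [
    (0, "s0a", False),
    (1, "s1a", False),
    (0, "s0b", False),
    (1, "s1b", True),
    (0, "s0c", True),
]
episodes = list(yield_trajectories(stream))
# → [["s1a", "s1b"], ["s0a", "s0b", "s0c"]]
```

Episodes come out in completion order (env 1 first here), not env order, which is why episode-level yields pair naturally with asynchronous stepping.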
Using MultiCollector
--------------------

docs/source/reference/modules_inference_server.rst

Lines changed: 26 additions & 0 deletions
@@ -92,3 +92,29 @@ to receive updated model weights from a trainer between inference batches:

    loss.backward()
    optimizer.step()
    weight_sync.send(model=training_model)  # pushed to server

Integration with Collectors
^^^^^^^^^^^^^^^^^^^^^^^^^^^

The easiest way to use the inference server with RL data collection is
through :class:`~torchrl.collectors.AsyncBatchedCollector`, which
creates the server, transport, and env pool automatically:

.. code-block:: python

    from torchrl.collectors import AsyncBatchedCollector
    from torchrl.envs import GymEnv

    collector = AsyncBatchedCollector(
        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
        policy=my_policy,
        frames_per_batch=200,
        total_frames=10_000,
        max_batch_size=8,
    )

    for data in collector:
        # train on data ...
        pass

    collector.shutdown()
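The ``max_batch_size`` parameter above caps how many pending requests the server folds into one forward pass. The mechanics of that auto-batching can be sketched in plain Python (``serve_batches`` and ``policy_fn`` are illustrative names for this sketch, not TorchRL API): drain up to ``max_batch_size`` queued observations, then run a single batched call instead of one call per observation.

```python
import queue


def serve_batches(requests, policy_fn, max_batch_size):
    """Drain pending requests into batches of at most max_batch_size and
    run one forward pass per batch, yielding (request_id, action) pairs."""
    while not requests.empty():
        batch = []
        while len(batch) < max_batch_size and not requests.empty():
            batch.append(requests.get_nowait())
        ids, obs = zip(*batch)
        # One batched call instead of len(batch) separate policy calls.
        actions = policy_fn(list(obs))
        yield from zip(ids, actions)


q = queue.Queue()
for i, o in enumerate([0.1, 0.2, 0.3, 0.4, 0.5]):
    q.put((i, o))

# Toy "policy": thresholds each observation in one pass over the batch.
out = list(serve_batches(q, lambda xs: [1 if x > 0.25 else 0 for x in xs],
                         max_batch_size=4))
# → [(0, 0), (1, 0), (2, 1), (3, 1), (4, 1)]
```

With five pending requests and ``max_batch_size=4``, this runs two forward passes (a batch of 4, then a batch of 1) rather than five.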
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@

"""AsyncBatchedCollector example.

Demonstrates how to use :class:`~torchrl.collectors.AsyncBatchedCollector` to
run many environments in parallel while automatically batching policy inference
through an :class:`~torchrl.modules.InferenceServer`.

Architecture:
- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel using
  the chosen backend (``"multiprocessing"`` by default for true parallelism,
  or ``"threading"``/``"asyncio"``).
- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
  and runs a single forward pass.
- A lightweight coordinator thread bridges the two: when an env finishes
  stepping, its observation is submitted to the server, and when an action is
  ready the env is sent back for stepping -- all without synchronisation
  barriers.

The user only supplies:
- A list of environment factories
- A policy (or policy factory)
"""
import torch.nn as nn
from tensordict.nn import TensorDictModule

from torchrl.collectors import AsyncBatchedCollector
from torchrl.envs import GymEnv


def make_env():
    """Factory that returns a CartPole environment."""
    return GymEnv("CartPole-v1")


def main():
    num_envs = 4
    frames_per_batch = 200
    total_frames = 1_000

    # A simple linear policy (random weights -- just for demonstration)
    policy = TensorDictModule(
        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
    )

    collector = AsyncBatchedCollector(
        create_env_fn=[make_env] * num_envs,
        policy=policy,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        max_batch_size=num_envs,
        device="cpu",
    )

    total_collected = 0
    for i, batch in enumerate(collector):
        n = batch.numel()
        total_collected += n
        print(f"Batch {i}: {batch.shape} ({n} frames, total={total_collected})")

    collector.shutdown()
    print("Done!")


if __name__ == "__main__":
    main()
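The coordinator pattern described in the example's docstring can be sketched with nothing but threads and queues (a toy model under stated assumptions: a single-request server loop and a trivial integer "environment"; the real InferenceServer additionally batches concurrent requests, as shown in the docs above):

```python
import queue
import threading


def run_pipeline(num_envs, steps_per_env, policy_fn):
    """Toy coordinator: env workers push observations to a shared queue,
    a server thread computes actions and routes them back, and each env
    steps again as soon as its action arrives -- no global sync barrier."""
    obs_q = queue.Queue()                     # envs -> server
    act_qs = [queue.Queue() for _ in range(num_envs)]  # server -> each env
    log = []

    def server():
        served = 0
        while served < num_envs * steps_per_env:
            env_id, obs = obs_q.get()
            act_qs[env_id].put(policy_fn(obs))  # stand-in for a forward pass
            served += 1

    def env_worker(env_id):
        obs = 0                                # toy "reset" observation
        for _ in range(steps_per_env):
            obs_q.put((env_id, obs))           # submit obs to the server
            action = act_qs[env_id].get()      # wait for this env's action
            obs += action                      # toy "step"
        log.append((env_id, obs))

    threads = [threading.Thread(target=server)]
    threads += [threading.Thread(target=env_worker, args=(i,))
                for i in range(num_envs)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(log)


result = run_pipeline(num_envs=3, steps_per_env=4, policy_fn=lambda obs: 1)
# → [(0, 4), (1, 4), (2, 4)]
```

Each env proceeds at its own pace: a slow env never blocks the others, which is the "overlapping fashion" advantage the docs claim for AsyncBatchedCollector.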

test/test_inference_server.py

Lines changed: 186 additions & 0 deletions
@@ -682,3 +682,189 @@ def test_no_weight_sync(self):
        td = TensorDict({"observation": torch.randn(4)})
        result = client(td)
        assert "action" in result.keys()


# ---------------------------------------------------------------------------
# AsyncBatchedCollector tests
# ---------------------------------------------------------------------------

from torchrl.collectors import AsyncBatchedCollector
from torchrl.testing.mocking_classes import CountingEnv


def _counting_env_factory(max_steps=5):
    """Factory that returns a CountingEnv."""
    return CountingEnv(max_steps=max_steps)


class _BatchCountingPolicy(TensorDictModule):
    """A batch-aware policy that always outputs action=1 for CountingEnv."""

    def __init__(self):
        super().__init__(
            module=nn.Module(),  # placeholder
            in_keys=["observation"],
            out_keys=["action"],
        )

    def forward(self, td: TensorDictBase) -> TensorDictBase:
        obs = td.get("observation")
        action = torch.ones_like(obs)
        return td.set("action", action)


def _make_counting_policy():
    return _BatchCountingPolicy()


class TestAsyncBatchedCollector:
    """Tests for :class:`AsyncBatchedCollector`."""

    def test_basic_collection(self):
        """Collector yields at least frames_per_batch frames."""
        num_envs = 3
        frames_per_batch = 20
        total_frames = 60
        policy = _make_counting_policy()

        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * num_envs,
            policy=policy,
            frames_per_batch=frames_per_batch,
            total_frames=total_frames,
            max_batch_size=num_envs,
            env_backend="threading",
        )
        total_collected = 0
        for batch in collector:
            assert batch is not None
            total_collected += batch.numel()
        collector.shutdown()
        assert total_collected >= total_frames

    def test_policy_factory(self):
        """policy_factory is called to create the policy."""
        num_envs = 2
        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * num_envs,
            policy_factory=_make_counting_policy,
            frames_per_batch=10,
            total_frames=20,
            max_batch_size=num_envs,
            env_backend="threading",
        )
        total_collected = 0
        for batch in collector:
            total_collected += batch.numel()
        collector.shutdown()
        assert total_collected >= 20

    def test_policy_xor_factory(self):
        """Providing both policy and policy_factory raises."""
        policy = _make_counting_policy()
        with pytest.raises(TypeError, match="mutually exclusive"):
            AsyncBatchedCollector(
                create_env_fn=[_counting_env_factory],
                policy=policy,
                policy_factory=_make_counting_policy,
                frames_per_batch=10,
            )

    def test_neither_policy_nor_factory(self):
        """Providing neither raises."""
        with pytest.raises(TypeError, match="must be provided"):
            AsyncBatchedCollector(
                create_env_fn=[_counting_env_factory],
                frames_per_batch=10,
            )

    def test_yield_completed_trajectories(self):
        """With yield_completed_trajectories, collector yields done trajectories."""
        num_envs = 3
        max_steps = 5
        policy = _make_counting_policy()

        collector = AsyncBatchedCollector(
            create_env_fn=[lambda: CountingEnv(max_steps=max_steps)] * num_envs,
            policy=policy,
            frames_per_batch=1,
            total_frames=30,
            yield_completed_trajectories=True,
            max_batch_size=num_envs,
            env_backend="threading",
        )
        count = 0
        for batch in collector:
            assert batch is not None
            # Each trajectory should end with done=True
            count += batch.numel()
        collector.shutdown()
        assert count >= 30

    def test_shutdown_idempotent(self):
        """Calling shutdown twice should not raise."""
        policy = _make_counting_policy()
        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * 2,
            policy=policy,
            frames_per_batch=10,
            total_frames=10,
            env_backend="threading",
        )
        # Consume one batch to start
        for _batch in collector:
            break
        collector.shutdown()
        collector.shutdown()  # should not raise

    def test_endless_collector(self):
        """total_frames=-1 creates an endless collector; verify manual break works."""
        policy = _make_counting_policy()
        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * 2,
            policy=policy,
            frames_per_batch=10,
            total_frames=-1,
            env_backend="threading",
        )
        collected = 0
        for batch in collector:
            collected += batch.numel()
            if collected >= 50:
                break
        collector.shutdown()
        assert collected >= 50

    def test_num_envs(self):
        """The collector knows the number of environments."""
        policy = _make_counting_policy()
        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * 2,
            policy=policy,
            frames_per_batch=10,
            total_frames=10,
        )
        assert collector._num_envs == 2
        collector.shutdown()

    def test_postproc(self):
        """Post-processing callable is applied to every batch."""
        policy = _make_counting_policy()
        called = {"count": 0}

        def postproc(td):
            called["count"] += 1
            return td

        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * 2,
            policy=policy,
            frames_per_batch=10,
            total_frames=20,
            postproc=postproc,
            env_backend="threading",
        )
        for _ in collector:
            pass
        collector.shutdown()
        assert called["count"] >= 1

torchrl/collectors/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -6,13 +6,14 @@

from torchrl.modules.tensordict_module.exploration import RandomPolicy

from ._async_batched import AsyncBatchedCollector
from ._base import BaseCollector, DataCollectorBase, ProfileConfig

from ._multi_async import MultiAsyncCollector, MultiaSyncDataCollector
from ._multi_base import MultiCollector, MultiCollector as _MultiDataCollector
from ._multi_sync import MultiSyncCollector, MultiSyncDataCollector
from ._single import Collector, SyncDataCollector
from ._single_async import AsyncCollector, aSyncDataCollector
from .weight_update import (
    MultiProcessedWeightUpdater,

@@ -29,6 +30,7 @@

    "AsyncCollector",
    "MultiCollector",
    "MultiSyncCollector",
    "AsyncBatchedCollector",
    "MultiAsyncCollector",
    "ProfileConfig",
    # Legacy names (backward-compatible aliases)

0 commit comments