Commit fe0bcfe

vmoens and cursoragent committed
[Feature] AsyncBatchedCollector: async envs + auto-batching inference (#3498)
Add AsyncBatchedCollector that pairs AsyncEnvPool with InferenceServer for pipelined RL data collection. Users supply env factories and a policy; the collector handles all internal wiring (transport, server, env pool).

Co-authored-by: Cursor <cursoragent@cursor.com>
ghstack-source-id: 525120c
Pull-Request: #3498
1 parent 509448b commit fe0bcfe

File tree

7 files changed (+721, -1 lines)

docs/source/reference/collectors.rst

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ making it easy to collect high-quality training data efficiently.
 TorchRL provides several collector implementations optimized for different scenarios:
 
 - :class:`Collector`: Single-process collection on the training worker
+- :class:`AsyncBatchedCollector`: Async environments + auto-batching inference server (see :class:`AsyncBatchedCollector`)
 - :class:`MultiCollector`: Parallel collection across multiple workers (see below)
 - **Distributed collectors**: For multi-node setups using Ray, RPC, or distributed backends (see :class:`DistributedCollector` / :class:`RPCCollector`)

docs/source/reference/collectors_single.rst

Lines changed: 44 additions & 0 deletions
@@ -15,6 +15,7 @@ Single node data collectors
    BaseCollector
    Collector
    AsyncCollector
+   AsyncBatchedCollector
    MultiCollector
    MultiSyncCollector
    MultiAsyncCollector
@@ -29,6 +30,49 @@ Single node data collectors
 - ``MultiSyncDataCollector`` → ``MultiSyncCollector``
 - ``MultiaSyncDataCollector`` → ``MultiAsyncCollector``
 
+Using AsyncBatchedCollector
+---------------------------
+
+The :class:`AsyncBatchedCollector` pairs an :class:`~torchrl.envs.AsyncEnvPool`
+with an :class:`~torchrl.modules.InferenceServer` to pipeline environment
+stepping and batched GPU inference. You only need to supply **env factories**
+and a **policy** -- all internal wiring is handled automatically:
+
+.. code-block:: python
+
+    from torchrl.collectors import AsyncBatchedCollector
+    from torchrl.envs import GymEnv
+    from tensordict.nn import TensorDictModule
+    import torch.nn as nn
+
+    policy = TensorDictModule(
+        nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)),
+        in_keys=["observation"],
+        out_keys=["action"],
+    )
+
+    collector = AsyncBatchedCollector(
+        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
+        policy=policy,
+        frames_per_batch=200,
+        total_frames=10000,
+        max_batch_size=8,
+    )
+
+    for data in collector:
+        # data is a lazy-stacked TensorDict of collected transitions
+        pass
+
+    collector.shutdown()
+
+**Key advantages over** :class:`Collector`:
+
+- The inference server automatically **batches policy forward passes** from
+  all environments, maximising GPU utilisation.
+- Environment stepping and inference run in **overlapping fashion**, reducing
+  idle time.
+- Supports ``yield_completed_trajectories=True`` for episode-level yields.
+
 Using MultiCollector
 --------------------
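The auto-batching idea in the bullets above can be sketched without TorchRL at all: worker threads submit observations to a queue, and a server loop drains up to ``max_batch_size`` pending requests and answers them with one model call. All names here (``BatchingServer``, ``submit``, ``serve_once``) are illustrative, not part of the TorchRL API -- this is a minimal sketch of the batching mechanism only:

```python
import queue
import threading

class BatchingServer:
    """Toy inference server: gathers pending requests, answers them in one call."""

    def __init__(self, model, max_batch_size):
        self.model = model                  # callable: list of obs -> list of actions
        self.max_batch_size = max_batch_size
        self.requests = queue.Queue()

    def submit(self, obs):
        """Called from env workers; blocks until the batched action is ready."""
        done = threading.Event()
        slot = {}
        self.requests.put((obs, slot, done))
        done.wait()
        return slot["action"]

    def serve_once(self):
        """Drain up to max_batch_size requests, run a single batched call."""
        batch = [self.requests.get()]       # block until at least one request
        while len(batch) < self.max_batch_size:
            try:
                batch.append(self.requests.get_nowait())
            except queue.Empty:
                break
        actions = self.model([obs for obs, _, _ in batch])  # one batched call
        for (_, slot, done), action in zip(batch, actions):
            slot["action"] = action
            done.set()
        return len(batch)

# Toy "policy" that doubles each observation.
server = BatchingServer(model=lambda obs_batch: [o * 2 for o in obs_batch],
                        max_batch_size=4)

results = {}

def worker(i):
    results[i] = server.submit(i)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()

# Serve until all four requests are answered (possibly over several batches).
answered = 0
while answered < 4:
    answered += server.serve_once()
for t in threads:
    t.join()
```

The real ``InferenceServer`` adds transports, devices, and weight sync on top; the essential trick is the same single forward pass over whatever requests are pending.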

docs/source/reference/modules_inference_server.rst

Lines changed: 26 additions & 0 deletions
@@ -92,3 +92,29 @@ to receive updated model weights from a trainer between inference batches:
     loss.backward()
     optimizer.step()
     weight_sync.send(model=training_model)  # pushed to server
+
+Integration with Collectors
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The easiest way to use the inference server with RL data collection is
+through :class:`~torchrl.collectors.AsyncBatchedCollector`, which
+creates the server, transport, and env pool automatically:
+
+.. code-block:: python
+
+    from torchrl.collectors import AsyncBatchedCollector
+    from torchrl.envs import GymEnv
+
+    collector = AsyncBatchedCollector(
+        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
+        policy=my_policy,
+        frames_per_batch=200,
+        total_frames=10_000,
+        max_batch_size=8,
+    )
+
+    for data in collector:
+        # train on data ...
+        pass
+
+    collector.shutdown()
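The ``weight_sync.send(...)`` pattern in the context above (trainer pushes weights, server pulls them between inference batches) can be illustrated with a toy, TorchRL-free version. ``WeightStore``, ``send``, and ``receive`` are hypothetical names for this sketch, not the library's transport API:

```python
import threading

class WeightStore:
    """Toy stand-in for a weight-sync transport: trainer pushes, server pulls."""

    def __init__(self, weights):
        self._lock = threading.Lock()
        self._weights = weights
        self._version = 0

    def send(self, weights):
        # Trainer side: publish a new weight version.
        with self._lock:
            self._weights = weights
            self._version += 1

    def receive(self, known_version):
        # Server side: fetch fresh weights only if the version advanced.
        with self._lock:
            if self._version != known_version:
                return self._weights, self._version
            return None, known_version

store = WeightStore(weights={"w": 1.0})
version = 0

# Trainer publishes an update between inference batches...
store.send({"w": 2.0})

# ...and the server picks it up before its next forward pass.
fresh, version = store.receive(version)
assert fresh == {"w": 2.0} and version == 1

# Nothing changed since the last pull: receive returns None.
fresh, version = store.receive(version)
assert fresh is None
```

Versioning keeps the server's pull cheap when no update happened, which matters when it is checked between every inference batch.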
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
"""AsyncBatchedCollector example.

Demonstrates how to use :class:`~torchrl.collectors.AsyncBatchedCollector` to
run many environments in parallel while automatically batching policy inference
through an :class:`~torchrl.modules.InferenceServer`.

Architecture:
- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel using
  the chosen backend (``"multiprocessing"`` by default for true parallelism,
  or ``"threading"``/``"asyncio"``).
- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
  and runs a single forward pass.
- A lightweight coordinator thread bridges the two: when an env finishes
  stepping, its observation is submitted to the server, and when an action is
  ready the env is sent back for stepping -- all without synchronisation
  barriers.

The user only supplies:
- A list of environment factories
- A policy (or policy factory)
"""
import torch.nn as nn
from tensordict.nn import TensorDictModule

from torchrl.collectors import AsyncBatchedCollector
from torchrl.envs import GymEnv


def make_env():
    """Factory that returns a CartPole environment."""
    return GymEnv("CartPole-v1")


def main():
    num_envs = 4
    frames_per_batch = 200
    total_frames = 1_000

    # A simple linear policy (random weights -- just for demonstration)
    policy = TensorDictModule(
        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
    )

    collector = AsyncBatchedCollector(
        create_env_fn=[make_env] * num_envs,
        policy=policy,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        max_batch_size=num_envs,
        device="cpu",
    )

    total_collected = 0
    for i, batch in enumerate(collector):
        n = batch.numel()
        total_collected += n
        print(f"Batch {i}: {batch.shape} ({n} frames, total={total_collected})")

    collector.shutdown()
    print("Done!")


if __name__ == "__main__":
    main()
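The coordinator behaviour described in the docstring above (each env is stepped as soon as its action is ready, with no barrier forcing all envs into lockstep) can be simulated in a few lines of TorchRL-free Python. ``ToyEnv`` and the loop below are purely illustrative; real asynchrony is replaced here by randomly choosing which envs have "finished stepping" on each tick:

```python
import random

class ToyEnv:
    """Toy env: observation is the step counter, episode ends at a random horizon."""

    def __init__(self, seed):
        self.rng = random.Random(seed)
        self.reset()

    def reset(self):
        self.t = 0
        self.horizon = self.rng.randint(2, 6)
        return self.t  # observation

    def step(self, action):
        self.t += 1
        return self.t, self.t >= self.horizon  # (next_obs, done)

rng = random.Random(0)
envs = [ToyEnv(seed=i) for i in range(4)]
obs = {i: env.reset() for i, env in enumerate(envs)}
frames_per_batch = 20
frames = 0
batch = []

while frames < frames_per_batch:
    # Simulate asynchrony: only the envs that have finished stepping this
    # tick submit an observation; the rest are still "busy".
    ready = [i for i in sorted(obs) if rng.random() < 0.7] or list(obs)
    actions = {i: obs[i] % 2 for i in ready}   # one "batched" policy call
    for i in ready:
        next_obs, done = envs[i].step(actions[i])
        batch.append((i, obs[i], actions[i], next_obs, done))
        frames += 1
        # Finished envs reset immediately and rejoin the pool.
        obs[i] = envs[i].reset() if done else next_obs
```

Note that, as in the real collector, envs at different episode lengths keep producing frames independently, and the batch is yielded once at least ``frames_per_batch`` frames have accumulated.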

test/test_inference_server.py

Lines changed: 186 additions & 0 deletions
@@ -677,3 +677,189 @@ def test_no_weight_sync(self):
         td = TensorDict({"observation": torch.randn(4)})
         result = client(td)
         assert "action" in result.keys()
+
+
+# ---------------------------------------------------------------------------
+# AsyncBatchedCollector tests
+# ---------------------------------------------------------------------------
+
+from torchrl.collectors import AsyncBatchedCollector
+from torchrl.testing.mocking_classes import CountingEnv
+
+
+def _counting_env_factory(max_steps=5):
+    """Factory that returns a CountingEnv."""
+    return CountingEnv(max_steps=max_steps)
+
+
+class _BatchCountingPolicy(TensorDictModule):
+    """A batch-aware policy that always outputs action=1 for CountingEnv."""
+
+    def __init__(self):
+        super().__init__(
+            module=nn.Module(),  # placeholder
+            in_keys=["observation"],
+            out_keys=["action"],
+        )
+
+    def forward(self, td: TensorDictBase) -> TensorDictBase:
+        obs = td.get("observation")
+        action = torch.ones_like(obs)
+        return td.set("action", action)
+
+
+def _make_counting_policy():
+    return _BatchCountingPolicy()
+
+
+class TestAsyncBatchedCollector:
+    """Tests for :class:`AsyncBatchedCollector`."""
+
+    def test_basic_collection(self):
+        """Collector yields at least frames_per_batch frames."""
+        num_envs = 3
+        frames_per_batch = 20
+        total_frames = 60
+        policy = _make_counting_policy()
+
+        collector = AsyncBatchedCollector(
+            create_env_fn=[_counting_env_factory] * num_envs,
+            policy=policy,
+            frames_per_batch=frames_per_batch,
+            total_frames=total_frames,
+            max_batch_size=num_envs,
+            env_backend="threading",
+        )
+        total_collected = 0
+        for batch in collector:
+            assert batch is not None
+            total_collected += batch.numel()
+        collector.shutdown()
+        assert total_collected >= total_frames
+
+    def test_policy_factory(self):
+        """policy_factory is called to create the policy."""
+        num_envs = 2
+        collector = AsyncBatchedCollector(
+            create_env_fn=[_counting_env_factory] * num_envs,
+            policy_factory=_make_counting_policy,
+            frames_per_batch=10,
+            total_frames=20,
+            max_batch_size=num_envs,
+            env_backend="threading",
+        )
+        total_collected = 0
+        for batch in collector:
+            total_collected += batch.numel()
+        collector.shutdown()
+        assert total_collected >= 20
+
+    def test_policy_xor_factory(self):
+        """Providing both policy and policy_factory raises."""
+        policy = _make_counting_policy()
+        with pytest.raises(TypeError, match="mutually exclusive"):
+            AsyncBatchedCollector(
+                create_env_fn=[_counting_env_factory],
+                policy=policy,
+                policy_factory=_make_counting_policy,
+                frames_per_batch=10,
+            )
+
+    def test_neither_policy_nor_factory(self):
+        """Providing neither raises."""
+        with pytest.raises(TypeError, match="must be provided"):
+            AsyncBatchedCollector(
+                create_env_fn=[_counting_env_factory],
+                frames_per_batch=10,
+            )
+
+    def test_yield_completed_trajectories(self):
+        """With yield_completed_trajectories, collector yields done trajectories."""
+        num_envs = 3
+        max_steps = 5
+        policy = _make_counting_policy()
+
+        collector = AsyncBatchedCollector(
+            create_env_fn=[lambda: CountingEnv(max_steps=max_steps)] * num_envs,
+            policy=policy,
+            frames_per_batch=1,
+            total_frames=30,
+            yield_completed_trajectories=True,
+            max_batch_size=num_envs,
+            env_backend="threading",
+        )
+        count = 0
+        for batch in collector:
+            assert batch is not None
+            # Each trajectory should end with done=True
+            count += batch.numel()
+        collector.shutdown()
+        assert count >= 30
+
+    def test_shutdown_idempotent(self):
+        """Calling shutdown twice should not raise."""
+        policy = _make_counting_policy()
+        collector = AsyncBatchedCollector(
+            create_env_fn=[_counting_env_factory] * 2,
+            policy=policy,
+            frames_per_batch=10,
+            total_frames=10,
+            env_backend="threading",
+        )
+        # Consume one batch to start
+        for _batch in collector:
+            break
+        collector.shutdown()
+        collector.shutdown()  # should not raise
+
+    def test_endless_collector(self):
+        """total_frames=-1 creates an endless collector; verify manual break works."""
+        policy = _make_counting_policy()
+        collector = AsyncBatchedCollector(
+            create_env_fn=[_counting_env_factory] * 2,
+            policy=policy,
+            frames_per_batch=10,
+            total_frames=-1,
+            env_backend="threading",
+        )
+        collected = 0
+        for batch in collector:
+            collected += batch.numel()
+            if collected >= 50:
+                break
+        collector.shutdown()
+        assert collected >= 50
+
+    def test_num_envs(self):
+        """The collector knows the number of environments."""
+        policy = _make_counting_policy()
+        collector = AsyncBatchedCollector(
+            create_env_fn=[_counting_env_factory] * 2,
+            policy=policy,
+            frames_per_batch=10,
+            total_frames=10,
+        )
+        assert collector._num_envs == 2
+        collector.shutdown()
+
+    def test_postproc(self):
+        """Post-processing callable is applied to every batch."""
+        policy = _make_counting_policy()
+        called = {"count": 0}
+
+        def postproc(td):
+            called["count"] += 1
+            return td
+
+        collector = AsyncBatchedCollector(
+            create_env_fn=[_counting_env_factory] * 2,
+            policy=policy,
+            frames_per_batch=10,
+            total_frames=20,
+            postproc=postproc,
+            env_backend="threading",
+        )
+        for _ in collector:
+            pass
+        collector.shutdown()
+        assert called["count"] >= 1

torchrl/collectors/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -6,13 +6,14 @@
 
 from torchrl.modules.tensordict_module.exploration import RandomPolicy
 
+from ._async_batched import AsyncBatchedCollector
 from ._base import BaseCollector, DataCollectorBase, ProfileConfig
 
 from ._multi_async import MultiAsyncCollector, MultiaSyncDataCollector
 from ._multi_base import MultiCollector, MultiCollector as _MultiDataCollector
 from ._multi_sync import MultiSyncCollector, MultiSyncDataCollector
 from ._single import Collector, SyncDataCollector
-
 from ._single_async import AsyncCollector, aSyncDataCollector
 from .weight_update import (
     MultiProcessedWeightUpdater,
@@ -29,6 +30,7 @@
     "AsyncCollector",
     "MultiCollector",
     "MultiSyncCollector",
+    "AsyncBatchedCollector",
     "MultiAsyncCollector",
     "ProfileConfig",
     # Legacy names (backward-compatible aliases)
0 commit comments
