Commit 122bc89

Update (base update)
[ghstack-poisoned]
1 parent 266e4aa commit 122bc89

File tree

18 files changed: +3064 −73 lines


benchmarks/bench_collectors.py

Lines changed: 419 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/reference/collectors.rst

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ making it easy to collect high-quality training data efficiently.

 TorchRL provides several collector implementations optimized for different scenarios:

 - :class:`Collector`: Single-process collection on the training worker
+- :class:`AsyncBatchedCollector`: Async environments + an auto-batching inference server
 - :class:`MultiCollector`: Parallel collection across multiple workers (see below)
 - **Distributed collectors**: For multi-node setups using Ray, RPC, or distributed backends (see :class:`DistributedCollector` / :class:`RPCCollector`)

docs/source/reference/collectors_single.rst

Lines changed: 44 additions & 0 deletions
@@ -15,6 +15,7 @@ Single node data collectors
    BaseCollector
    Collector
    AsyncCollector
+   AsyncBatchedCollector
    MultiCollector
    MultiSyncCollector
    MultiAsyncCollector
@@ -29,6 +30,49 @@ Single node data collectors
 - ``MultiSyncDataCollector`` → ``MultiSyncCollector``
 - ``MultiaSyncDataCollector`` → ``MultiAsyncCollector``

+Using AsyncBatchedCollector
+---------------------------
+
+The :class:`AsyncBatchedCollector` pairs an :class:`~torchrl.envs.AsyncEnvPool`
+with an :class:`~torchrl.modules.InferenceServer` to pipeline environment
+stepping and batched GPU inference. You only need to supply **env factories**
+and a **policy** -- all internal wiring is handled automatically:
+
+.. code-block:: python
+
+    from torchrl.collectors import AsyncBatchedCollector
+    from torchrl.envs import GymEnv
+    from tensordict.nn import TensorDictModule
+    import torch.nn as nn
+
+    policy = TensorDictModule(
+        nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)),
+        in_keys=["observation"],
+        out_keys=["action"],
+    )
+
+    collector = AsyncBatchedCollector(
+        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
+        policy=policy,
+        frames_per_batch=200,
+        total_frames=10000,
+        max_batch_size=8,
+    )
+
+    for data in collector:
+        # data is a lazy-stacked TensorDict of collected transitions
+        pass
+
+    collector.shutdown()
+
+**Key advantages over** :class:`Collector`:
+
+- The inference server automatically **batches policy forward passes** from
+  all environments, maximising GPU utilisation.
+- Environment stepping and inference run in **overlapping fashion**, reducing
+  idle time.
+- Supports ``yield_completed_trajectories=True`` for episode-level yields.
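The "no global synchronisation barrier" advantage can be illustrated with some toy arithmetic (illustrative numbers only, not TorchRL code): within a fixed time budget, a lock-step collector is throttled by its slowest environment at every step, while async collection lets each environment contribute frames at its own rate.

```python
# Toy throughput model (illustrative arithmetic only, not torchrl code):
# two envs with different per-step costs and a fixed time budget T.
step_cost = {"fast_env": 1.0, "slow_env": 4.0}
T = 100.0

# Synchronous collection: envs advance in lock-step, so every joint step
# costs as much as the slowest env, for both envs.
sync_frames = len(step_cost) * T / max(step_cost.values())

# Async collection: each env steps at its own pace within the same budget.
async_frames = sum(T / cost for cost in step_cost.values())

print(sync_frames, async_frames)  # 50.0 125.0
```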
 Using MultiCollector
 --------------------

docs/source/reference/modules.rst

Lines changed: 1 addition & 0 deletions
@@ -56,4 +56,5 @@ Documentation Sections
    modules_mcts
    modules_models
    modules_distributions
+   modules_inference_server
    modules_utils
docs/source/reference/modules_inference_server.rst

Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+.. currentmodule:: torchrl.modules.inference_server
+
+Inference Server
+================
+
+.. _ref_inference_server:
+
+The inference server provides auto-batching model serving for RL actors.
+Multiple actors submit individual TensorDicts; the server transparently
+batches them, runs a single model forward pass, and routes results back.
+
+Core API
+--------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    InferenceServer
+    InferenceClient
+    InferenceTransport
+
+Transport Backends
+------------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    ThreadingTransport
+    MPTransport
+    RayTransport
+    MonarchTransport
+
+Usage
+-----
+
+The simplest setup uses :class:`ThreadingTransport` for actors that are
+threads in the same process:
+
+.. code-block:: python
+
+    from tensordict.nn import TensorDictModule
+    from torchrl.modules.inference_server import (
+        InferenceServer,
+        ThreadingTransport,
+    )
+    import torch.nn as nn
+    import concurrent.futures
+
+    policy = TensorDictModule(
+        nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4)),
+        in_keys=["observation"],
+        out_keys=["action"],
+    )
+
+    transport = ThreadingTransport()
+    server = InferenceServer(policy, transport, max_batch_size=32)
+    server.start()
+    client = transport.client()
+
+    # actor threads call client(td) -- batched automatically
+    with concurrent.futures.ThreadPoolExecutor(16) as pool:
+        ...
+
+    server.shutdown()
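The pool body above is intentionally elided; the auto-batching behaviour it relies on can be sketched in plain Python. This is a toy stand-in for the server/client pair, not the TorchRL implementation: `ToyBatchingServer` and its list-doubling "model" are invented for illustration.

```python
import queue
import threading
import concurrent.futures

class ToyBatchingServer:
    """Toy auto-batching server: actors submit single items, a server
    thread drains the queue up to max_batch_size, runs ONE call over the
    whole batch, and routes each result back to its caller."""

    def __init__(self, model, max_batch_size=32):
        self.model = model  # callable over a list of inputs
        self.max_batch_size = max_batch_size
        self._requests = queue.Queue()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._serve, daemon=True)

    def start(self):
        self._thread.start()

    def __call__(self, item):
        # Client entry point: blocks until the batched result arrives.
        fut = concurrent.futures.Future()
        self._requests.put((item, fut))
        return fut.result()

    def _serve(self):
        while not self._stop.is_set():
            try:
                first = self._requests.get(timeout=0.1)
            except queue.Empty:
                continue
            batch = [first]
            # Opportunistically drain whatever else is already queued.
            while len(batch) < self.max_batch_size:
                try:
                    batch.append(self._requests.get_nowait())
                except queue.Empty:
                    break
            items, futs = zip(*batch)
            outputs = self.model(list(items))  # one "forward pass" per batch
            for fut, out in zip(futs, outputs):
                fut.set_result(out)

    def shutdown(self):
        self._stop.set()
        self._thread.join()

server = ToyBatchingServer(lambda xs: [x * 2 for x in xs], max_batch_size=8)
server.start()
with concurrent.futures.ThreadPoolExecutor(16) as pool:
    results = list(pool.map(server, range(100)))
server.shutdown()
print(results[:3])  # [0, 2, 4]
```

Each of the 16 worker threads blocks on its own future, so the server is free to group concurrent requests into a single model call, which is the same effect the real server exploits to batch GPU inference.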
+Weight Synchronisation
+^^^^^^^^^^^^^^^^^^^^^^
+
+The server integrates with :class:`~torchrl.weight_update.WeightSyncScheme`
+to receive updated model weights from a trainer between inference batches:
+
+.. code-block:: python
+
+    from torchrl.weight_update import SharedMemWeightSyncScheme
+
+    weight_sync = SharedMemWeightSyncScheme()
+    # Initialise on the trainer (sender) side first
+    weight_sync.init_on_sender(model=training_model, ...)
+
+    server = InferenceServer(
+        model=inference_model,
+        transport=ThreadingTransport(),
+        weight_sync=weight_sync,
+    )
+    server.start()
+
+    # Training loop
+    for batch in dataloader:
+        optimizer.zero_grad()
+        loss = loss_fn(training_model(batch))
+        loss.backward()
+        optimizer.step()
+        weight_sync.send(model=training_model)  # pushed to server
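The push-style flow above can be mimicked with a tiny versioned snapshot store. This is plain Python for illustration only, not the `WeightSyncScheme` API: `ToyWeightStore` and its method names are invented.

```python
import threading

class ToyWeightStore:
    """Toy weight sync: the trainer publishes a versioned snapshot under a
    lock; the server copies the latest snapshot before an inference batch."""

    def __init__(self, weights):
        self._lock = threading.Lock()
        self._weights = dict(weights)
        self._version = 0

    def send(self, weights):
        # Trainer side: publish a new snapshot (called after optimizer.step()).
        with self._lock:
            self._weights = dict(weights)
            self._version += 1

    def receive(self):
        # Server side: grab the latest snapshot and its version.
        with self._lock:
            return dict(self._weights), self._version

store = ToyWeightStore({"w": 0.0})
for step in range(3):                # stand-in for the training loop
    store.send({"w": float(step)})

weights, version = store.receive()   # server picks this up between batches
print(weights, version)  # {'w': 2.0} 3
```

The lock plus copy-on-read keeps the server's view consistent even if the trainer publishes mid-batch; shared-memory schemes achieve the same consistency without the copy.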
+Integration with Collectors
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The easiest way to use the inference server with RL data collection is
+through :class:`~torchrl.collectors.AsyncBatchedCollector`, which
+creates the server, transport, and env pool automatically:
+
+.. code-block:: python
+
+    from torchrl.collectors import AsyncBatchedCollector
+    from torchrl.envs import GymEnv
+
+    collector = AsyncBatchedCollector(
+        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
+        policy=my_policy,
+        frames_per_batch=200,
+        total_frames=10_000,
+        max_batch_size=8,
+    )
+
+    for data in collector:
+        # train on data ...
+        pass
+
+    collector.shutdown()
Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+"""AsyncBatchedCollector example.
+
+Demonstrates how to use :class:`~torchrl.collectors.AsyncBatchedCollector` to
+run many environments in parallel while automatically batching policy inference
+through an :class:`~torchrl.modules.InferenceServer`.
+
+Architecture:
+- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel
+  using the chosen backend (``"threading"`` or ``"multiprocessing"``).
+- One lightweight coordinator thread per environment owns a slot in the pool
+  and an inference client.
+- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
+  and runs a single forward pass.
+- There is no global synchronisation barrier -- fast envs keep stepping
+  while slow ones wait for inference.
+
+The user only supplies:
+- A list of environment factories
+- A policy (or policy factory)
+"""
+import torch.nn as nn
+from tensordict.nn import TensorDictModule
+
+from torchrl.collectors import AsyncBatchedCollector
+from torchrl.envs import GymEnv
+
+
+def make_env():
+    """Factory that returns a CartPole environment."""
+    return GymEnv("CartPole-v1")
+
+
+def main():
+    num_envs = 4
+    frames_per_batch = 200
+    total_frames = 1_000
+
+    # A simple linear policy (random weights -- just for demonstration)
+    policy = TensorDictModule(
+        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
+    )
+
+    collector = AsyncBatchedCollector(
+        create_env_fn=[make_env] * num_envs,
+        policy=policy,
+        frames_per_batch=frames_per_batch,
+        total_frames=total_frames,
+        max_batch_size=num_envs,
+        device="cpu",
+    )
+
+    total_collected = 0
+    for i, batch in enumerate(collector):
+        n = batch.numel()
+        total_collected += n
+        print(f"Batch {i}: {batch.shape} ({n} frames, total={total_collected})")
+
+    collector.shutdown()
+    print("Done!")
+
+
+if __name__ == "__main__":
+    main()
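The `batch.numel()` bookkeeping in the loop above counts frames as the product of the batch dimensions; a quick plain-Python sanity check of that accounting (toy shape, not torchrl code):

```python
import math

# Toy frame accounting: a lazy-stacked batch of shape
# [num_envs, steps_per_env] contributes num_envs * steps_per_env frames.
def frames_in(shape):
    return math.prod(shape)

batch_shape = (4, 50)  # hypothetical: 4 envs, 50 steps each
print(frames_in(batch_shape))  # 200
```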

setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ per-file-ignores =
     .github/unittest/helpers/*.py: T001, T201
     docs/source/conf.py: T001
     test/test_libs.py: T001
+    benchmarks/*.py: T001, T201
     torchrl/_utils.py: T002

 exclude = venv

0 commit comments
